diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index df431f1a..7f511c5f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -19,7 +19,7 @@ jobs: testing: strategy: matrix: - go-version: [1.12.x,1.13.x,1.14.x,1.15.x,1.16.x,1.17.x,1.18.x,1.19.x] + go-version: [1.12.x,1.13.x,1.14.x,1.15.x,1.16.x,1.17.x,1.18.x,1.19.x,1.20.x,1.21.x] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} steps: diff --git a/agent.go b/agent.go index 17884ad4..6428ff09 100644 --- a/agent.go +++ b/agent.go @@ -133,7 +133,7 @@ func (a *Agent) RegisterCommandHandler(f CommandHandler) { a.commandHandlers = append(a.commandHandlers, f) } -func (a *Agent) GetDynamicRegistryInfo() *registrySnapInfoStorage { +func (a *Agent) GetDynamicRegistryInfo() *RegistrySnapInfoStorage { return a.configurer.getRegistryInfo() } @@ -150,7 +150,7 @@ func (a *Agent) RuntimeDir() string { return a.runtimedir } -// get Agent server +// GetAgentServer get Agent server func (a *Agent) GetAgentServer() motan.Server { return a.agentServer } @@ -300,7 +300,7 @@ func (a *Agent) initStatus() { func (a *Agent) saveStatus() { statSnapFile := a.runtimedir + string(filepath.Separator) + defaultStatusSnap - err := ioutil.WriteFile(statSnapFile, []byte(strconv.Itoa(int(http.StatusOK))), 0644) + err := ioutil.WriteFile(statSnapFile, []byte(strconv.Itoa(http.StatusOK)), 0644) if err != nil { vlog.Errorln("Save status error: " + err.Error()) return @@ -351,6 +351,14 @@ func (a *Agent) initParam() { initLog(logDir, section) registerSwitchers(a.Context) + processPoolSize := 0 + if section != nil && section["processPoolSize"] != nil { + processPoolSize = section["processPoolSize"].(int) + } + if processPoolSize > 0 { + mserver.SetProcessPoolSize(processPoolSize) + } + port := *motan.Port if port == 0 && section != nil && section["port"] != nil { port = section["port"].(int) @@ -474,7 +482,7 @@ func (a *Agent) reloadClusters(ctx *motan.Context) { serviceItemKeep := make(map[string]bool) clusterMap := make(map[interface{}]interface{}) serviceMap := make(map[interface{}]interface{}) - var allRefersURLs = []*motan.URL{} + var allRefersURLs []*motan.URL if a.configurer != nil { //keep all dynamic refers for _, url := range a.configurer.subscribeNodes { @@ -490,7 +498,7 @@ func (a *Agent) reloadClusters(ctx *motan.Context) { } service := url.Path - mapKey := getClusterKey(url.Group, url.GetStringParamsWithDefault(motan.VersionKey, "0.1"), url.Protocol, url.Path) + mapKey := getClusterKey(url.Group, url.GetStringParamsWithDefault(motan.VersionKey, motan.DefaultReferVersion), url.Protocol, url.Path) // find exists old serviceMap var serviceMapValue serviceMapItem @@ -589,7 +597,7 @@ func (a *Agent) initCluster(url *motan.URL) { } a.serviceMap.UnsafeStore(url.Path, serviceMapItemArr) }) - mapKey := getClusterKey(url.Group, url.GetStringParamsWithDefault(motan.VersionKey, "0.1"), url.Protocol, url.Path) + mapKey := getClusterKey(url.Group, url.GetStringParamsWithDefault(motan.VersionKey, motan.DefaultReferVersion), url.Protocol, url.Path) a.clsLock.Lock() // Mutually exclusive with the reloadClusters method defer a.clsLock.Unlock() a.clusterMap.Store(mapKey, c) @@ -748,7 +756,9 @@ func (a *agentMessageHandler) httpCall(request motan.Request, ck string, httpClu if err != nil { return getDefaultResponse(request.GetRequestID(), "do http request failed : "+err.Error()) } - res = &motan.MotanResponse{RequestID: request.GetRequestID()} + httpMotanResp := mhttp.AcquireHttpMotanResponse() + httpMotanResp.RequestID = request.GetRequestID() + res = httpMotanResp mhttp.FasthttpResponseToMotanResponse(res, httpResponse) return res } @@ -794,61 +804,38 @@ func (a *agentMessageHandler) Call(request motan.Request) (res motan.Response) { } return res } -func (a *agentMessageHandler) matchRule(typ, cond, key string, data []serviceMapItem, f func(u *motan.URL) string) (foundClusters []serviceMapItem, err error) { - if cond == "" { - err = fmt.Errorf("empty %s is not supported", typ) + +func (a *agentMessageHandler) findCluster(request motan.Request) (c *cluster.MotanCluster, key string, err error) { + service := request.GetServiceName() + if service == "" { + err = fmt.Errorf("empty service is not supported. service: %s", service) return } - for _, item := range data { - if f(item.url) == cond { - foundClusters = append(foundClusters, item) - } + serviceItemArrI, exists := a.agent.serviceMap.Load(service) + if !exists { + err = fmt.Errorf("cluster not found. service: %s", service) + return } - if len(foundClusters) == 0 { - err = fmt.Errorf("cluster not found. cluster:%s", key) + clusters := serviceItemArrI.([]serviceMapItem) + if len(clusters) == 1 { + //TODO: add strict mode to avoid incorrect group call + c = clusters[0].cluster return } - return -} -func (a *agentMessageHandler) findCluster(request motan.Request) (c *cluster.MotanCluster, key string, err error) { - service := request.GetServiceName() group := request.GetAttachment(mpro.MGroup) - version := request.GetAttachment(mpro.MVersion) - protocol := request.GetAttachment(mpro.MProxyProtocol) - reqInfo := fmt.Sprintf("request information: {service: %s, group: %s, protocol: %s, version: %s}", - service, group, protocol, version) - serviceItemArrI, exists := a.agent.serviceMap.Load(service) - if !exists { - err = fmt.Errorf("cluster not found. cluster:%s, %s", service, reqInfo) + if group == "" { + err = fmt.Errorf("multiple clusters are matched with service: %s, but the group is empty", service) return } - search := []struct { - tip string - cond string - condFn func(u *motan.URL) string - }{ - {"service", service, func(u *motan.URL) string { return u.Path }}, - {"group", group, func(u *motan.URL) string { return u.Group }}, - {"protocol", protocol, func(u *motan.URL) string { return u.Protocol }}, - {"version", version, func(u *motan.URL) string { return u.GetParam(motan.VersionKey, "") }}, - } - foundClusters := serviceItemArrI.([]serviceMapItem) - for i, rule := range search { - if i == 0 { - key = rule.cond - } else { - key += "_" + rule.cond - } - foundClusters, err = a.matchRule(rule.tip, rule.cond, key, foundClusters, rule.condFn) - if err != nil { - return - } - if len(foundClusters) == 1 { - c = foundClusters[0].cluster + version := request.GetAttachment(mpro.MVersion) + protocol := request.GetAttachment(mpro.MProxyProtocol) + for _, j := range clusters { + if j.url.IsMatch(service, group, protocol, version) { + c = j.cluster return } } - err = fmt.Errorf("less condition to select cluster, maybe this service belongs to multiple group, protocol, version; cluster: %s, %s", key, reqInfo) + err = fmt.Errorf("no cluster matches the request; info: {service: %s, group: %s, protocol: %s, version: %s}", service, group, protocol, version) return } @@ -1145,7 +1132,7 @@ func (a *Agent) startMServer() { continue } a.mport = port - managementListener = motan.TCPKeepAliveListener{listener.(*net.TCPListener)} + managementListener = motan.TCPKeepAliveListener{TCPListener: listener.(*net.TCPListener)} break } if managementListener == nil { @@ -1158,7 +1145,7 @@ func (a *Agent) startMServer() { vlog.Infof("listen manage port %d failed:%s", a.mport, err.Error()) return } - managementListener = motan.TCPKeepAliveListener{listener.(*net.TCPListener)} + managementListener = motan.TCPKeepAliveListener{TCPListener: listener.(*net.TCPListener)} } vlog.Infof("start listen manage for address: %s", managementListener.Addr().String()) diff --git a/agent_test.go b/agent_test.go index 55b8929d..3942d5ad 100644 --- a/agent_test.go +++ b/agent_test.go @@ -11,6 +11,7 @@ import ( vlog "github.com/weibocom/motan-go/log" "github.com/weibocom/motan-go/registry" "github.com/weibocom/motan-go/serialize" + "github.com/weibocom/motan-go/server" _ "github.com/weibocom/motan-go/server" _ "golang.org/x/net/context" "io/ioutil" @@ -63,6 +64,7 @@ motan-agent: log_dir: "stdout" snapshot_dir: "./snapshot" application: "testing" + processPoolSize: 100 motan-registry: direct: @@ -87,6 +89,7 @@ motan-refer: resp := c1.BaseCall(req, nil) assert.Nil(t, resp.GetException()) assert.Equal(t, "Hello jack from motan server", resp.GetValue()) + assert.Equal(t, 100, server.GetProcessPoolSize()) } func Test_unixClientCall2(t *testing.T) { t.Parallel() @@ -387,7 +390,7 @@ func TestAgent_InitCall(t *testing.T) { } //test init cluster with one path and one groups in clusterMap - temp := agent.clusterMap.LoadOrNil(getClusterKey("test1", "0.1", "", "")) + temp := agent.clusterMap.LoadOrNil(getClusterKey("test1", "1.0", "", "")) assert.NotNil(t, temp, "init cluster with one path and two groups in clusterMap fail") //test agentHandler call with group @@ -413,15 +416,18 @@ func TestAgent_InitCall(t *testing.T) { version string except string }{ + // only input service,and there is only one cluster,findCluster would return successfully {"test0", "", "", "", "No refers for request"}, - {"test-1", "111", "222", "333", "cluster not found. cluster:test-1"}, - {"test3", "", "", "", "empty group is not supported"}, + {"test0", "g0", "", "", "No refers for request"}, + {"test0", "g0", "http", "", "No refers for request"}, + {"test0", "g0", "", "1.3", "No refers for request"}, + {"test-1", "111", "222", "333", "cluster not found"}, {"test", "g2", "", "", "No refers for request"}, - {"test", "g1", "", "", "empty protocol is not supported"}, {"test", "g1", "motan2", "", "No refers for request"}, - {"test", "g1", "motan", "", "empty version is not supported"}, {"test", "g1", "http", "1.3", "No refers for request"}, - {"test", "g1", "http", "1.2", "less condition to select cluster"}, + {"test", "b", "c", "d", "no cluster matches the request"}, + // one service matches multiple clusters, without passing group + {"test", "", "c", "d", "multiple clusters are matched with service"}, } { request.ServiceName = v.service request.SetAttachment(mpro.MGroup, v.group) @@ -479,10 +485,9 @@ func TestAgent_InitCall(t *testing.T) { version string except string }{ - {"test3", "111", "222", "333", "cluster not found. cluster:test3"}, - {"test4", "", "", "", "empty group is not supported"}, + {"test3", "111", "222", "333", "cluster not found. service: test3"}, {"test5", "", "", "", "No refers for request"}, - {"helloService2", "", "", "", "cluster not found. cluster:helloService2"}, + {"helloService2", "", "", "", "cluster not found. service: helloService2"}, } { request = newRequest(v.service, "") request.SetAttachment(mpro.MGroup, v.group) @@ -633,7 +638,7 @@ motan-service: c1.Initialize() var reply []byte req := c1.BuildRequestWithGroup("helloService", "/unixclient", []interface{}{}, "hello") - req.SetAttachment("HTTP_HOST", "test.com") + req.SetAttachment("http_Host", "test.com") resp := c1.BaseCall(req, &reply) assert.Nil(t, resp.GetException()) assert.Equal(t, "okay", string(reply)) diff --git a/cluster/command.go b/cluster/command.go index f49ad7af..f80405f7 100644 --- a/cluster/command.go +++ b/cluster/command.go @@ -108,9 +108,11 @@ type CmdList []ClientCommand func (c CmdList) Len() int { return len(c) } + func (c CmdList) Swap(i, j int) { c[i], c[j] = c[j], c[i] } + func (c CmdList) Less(i, j int) bool { return c[i].Index < c[j].Index } @@ -149,7 +151,7 @@ func GetCommandRegistryWrapper(cluster *MotanCluster, registry motan.Registry) m mixGroups := cluster.GetURL().GetParam(motan.MixGroups, "") if mixGroups != "" { groups := strings.Split(mixGroups, ",") - command := &ClientCommand{CommandType: CMDTrafficControl, Index: 0, Version: "1.0", MergeGroups: make([]string, 0, len(groups)+1)} + command := &ClientCommand{CommandType: CMDTrafficControl, Index: 0, Version: motan.DefaultReferVersion, MergeGroups: make([]string, 0, len(groups)+1)} ownGroup := cluster.GetURL().Group command.MergeGroups = append(command.MergeGroups, ownGroup) for _, group := range groups { diff --git a/cluster/motanCluster.go b/cluster/motanCluster.go index b0180ac4..4bae0690 100644 --- a/cluster/motanCluster.go +++ b/cluster/motanCluster.go @@ -60,6 +60,7 @@ func (m *MotanCluster) GetURL() *motan.URL { func (m *MotanCluster) SetURL(url *motan.URL) { m.url = url } + func (m *MotanCluster) Call(request motan.Request) (res motan.Response) { defer motan.HandlePanic(func() { res = motan.BuildExceptionResponse(request.GetRequestID(), &motan.Exception{ErrCode: 500, ErrMsg: "cluster call panic", ErrType: motan.ServiceException}) @@ -71,6 +72,7 @@ func (m *MotanCluster) Call(request motan.Request) (res motan.Response) { vlog.Infoln("cluster:" + m.GetIdentity() + "is not available!") return motan.BuildExceptionResponse(request.GetRequestID(), &motan.Exception{ErrCode: 500, ErrMsg: "cluster not available, maybe caused by degrade", ErrType: motan.ServiceException}) } + func (m *MotanCluster) initCluster() bool { m.registryRefers = make(map[string][]motan.EndPoint) //ha @@ -99,15 +101,19 @@ func (m *MotanCluster) initCluster() bool { vlog.Infof("init MotanCluster %s", m.GetIdentity()) return true } + func (m *MotanCluster) SetLoadBalance(loadBalance motan.LoadBalance) { m.LoadBalance = loadBalance } + func (m *MotanCluster) SetHaStrategy(haStrategy motan.HaStrategy) { m.HaStrategy = haStrategy } + func (m *MotanCluster) GetRefers() []motan.EndPoint { return m.Refers } + func (m *MotanCluster) refresh() { newRefers := make([]motan.EndPoint, 0, 32) for _, v := range m.registryRefers { @@ -120,14 +126,17 @@ func (m *MotanCluster) refresh() { m.Refers = newRefers m.LoadBalance.OnRefresh(newRefers) } + func (m *MotanCluster) ShuffleEndpoints(endpoints []motan.EndPoint) []motan.EndPoint { rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(endpoints), func(i, j int) { endpoints[i], endpoints[j] = endpoints[j], endpoints[i] }) return endpoints } + func (m *MotanCluster) AddRegistry(registry motan.Registry) { m.Registries = append(m.Registries, registry) } + func (m *MotanCluster) Notify(registryURL *motan.URL, urls []*motan.URL) { vlog.Infof("cluster %s receive notify size %d. ", m.GetIdentity(), len(urls)) m.notifyLock.Lock() diff --git a/core/bytes.go b/core/bytes.go index 604c6d57..26dedc27 100644 --- a/core/bytes.go +++ b/core/bytes.go @@ -4,6 +4,16 @@ import ( "encoding/binary" "errors" "io" + "sync" +) + +var ( + bytesBufferPool = sync.Pool{New: func() interface{} { + return &BytesBuffer{ + temp: make([]byte, 8), + order: binary.BigEndian, + } + }} ) // BytesBuffer is a variable-sized buffer of bytes with Read and Write methods. @@ -19,14 +29,15 @@ type BytesBuffer struct { var ErrNotEnough = errors.New("BytesBuffer: not enough bytes") var ErrOverflow = errors.New("BytesBuffer: integer overflow") -// NewBytesBuffer create a empty BytesBuffer with initial size +// NewBytesBuffer create an empty BytesBuffer with initial size func NewBytesBuffer(initsize int) *BytesBuffer { return NewBytesBufferWithOrder(initsize, binary.BigEndian) } -// NewBytesBufferWithOrder create a empty BytesBuffer with initial size and byte order +// NewBytesBufferWithOrder create an empty BytesBuffer with initial size and byte order func NewBytesBufferWithOrder(initsize int, order binary.ByteOrder) *BytesBuffer { - return &BytesBuffer{buf: make([]byte, initsize), + return &BytesBuffer{ + buf: make([]byte, initsize), order: order, temp: make([]byte, 8), } @@ -78,6 +89,16 @@ func (b *BytesBuffer) WriteByte(c byte) { b.wpos++ } +// WriteString write a str string append the BytesBuffer, and the wpos will increase len(str) +func (b *BytesBuffer) WriteString(str string) { + l := len(str) + if len(b.buf) < b.wpos+l { + b.grow(l) + } + copy(b.buf[b.wpos:], str) + b.wpos += l +} + // Write write a byte array append the BytesBuffer, and the wpos will increase len(bytes) func (b *BytesBuffer) Write(bytes []byte) { l := len(bytes) @@ -117,11 +138,11 @@ func (b *BytesBuffer) WriteUint64(u uint64) { } func (b *BytesBuffer) WriteZigzag32(u uint32) int { - return b.WriteVarint(uint64((u << 1) ^ uint32((int32(u) >> 31)))) + return b.WriteVarint(uint64((u << 1) ^ uint32(int32(u)>>31))) } func (b *BytesBuffer) WriteZigzag64(u uint64) int { - return b.WriteVarint(uint64((u << 1) ^ uint64((int64(u) >> 63)))) + return b.WriteVarint((u << 1) ^ uint64(int64(u)>>63)) } func (b *BytesBuffer) WriteVarint(u uint64) int { @@ -225,10 +246,10 @@ func (b *BytesBuffer) ReadVarint() (x uint64, err error) { return 0, err } if (temp & 0x80) != 0x80 { - x |= (uint64(temp) << offset) + x |= uint64(temp) << offset return x, nil } - x |= (uint64(temp&0x7f) << offset) + x |= uint64(temp&0x7f) << offset } return 0, ErrOverflow } @@ -262,3 +283,20 @@ func (b *BytesBuffer) Remain() int { return b.wpos - b.rpos } func (b *BytesBuffer) Len() int { return b.wpos - 0 } func (b *BytesBuffer) Cap() int { return cap(b.buf) } + +// AcquireBytesBuffer create an empty BytesBuffer with initial size and byte order from bytesBufferPool +func AcquireBytesBuffer(initSize int) *BytesBuffer { + bb := bytesBufferPool.Get().(*BytesBuffer) + if bb.buf == nil { + bb.buf = make([]byte, initSize) + } + return bb +} + +// ReleaseBytesBuffer put the BytesBuffer to bytesBufferPool +func ReleaseBytesBuffer(b *BytesBuffer) { + if b != nil { + b.Reset() + bytesBufferPool.Put(b) + } +} diff --git a/core/bytes_test.go b/core/bytes_test.go index f842cf76..07b329cb 100644 --- a/core/bytes_test.go +++ b/core/bytes_test.go @@ -3,6 +3,7 @@ package core import ( "encoding/binary" "fmt" + "github.com/stretchr/testify/assert" "testing" ) @@ -238,3 +239,55 @@ func TestZigzag(t *testing.T) { } } } + +func TestBytesBuffer_WriteString_Grow(t *testing.T) { + a := BytesBuffer{} + a.WriteString("") + assert.Equal(t, 0, len(a.buf)) + a.WriteString("abc") + assert.Equal(t, "abc", string(a.Bytes())) + assert.Equal(t, 3, len(a.buf)) + a.WriteString("abc") + assert.Equal(t, "abcabc", string(a.Bytes())) + assert.Equal(t, 9, len(a.buf)) +} + +func TestBytesBuffer_WriteString_NoGrow(t *testing.T) { + a := BytesBuffer{buf: make([]byte, 4)} + a.WriteString("abc") + assert.Equal(t, "abc", string(a.Bytes())) + assert.Equal(t, 4, len(a.buf)) +} + +func TestDefaultBytesBufferPool(t *testing.T) { + // consume pool + for { + bb := bytesBufferPool.Get().(*BytesBuffer) + if bb.buf == nil { + break + } + } + // test new BytesBuffer + bb := AcquireBytesBuffer(10) + assert.Equal(t, 0, bb.Len()) + assert.Equal(t, 10, len(bb.buf)) + assert.Equal(t, 10, bb.Cap()) + + // test release and acquire + ReleaseBytesBuffer(bb) + newBb := AcquireBytesBuffer(10) + assert.NotEqual(t, nil, newBb) + assert.Equal(t, 0, bb.Len()) + assert.Equal(t, 10, len(bb.buf)) + assert.Equal(t, 10, bb.Cap()) + + // test put nil + var nilByteBuffer *BytesBuffer + // can not put nil to pool + ReleaseBytesBuffer(nilByteBuffer) + nilBb := bytesBufferPool.Get().(*BytesBuffer) + assert.NotEqual(t, nil, nilBb) + ReleaseBytesBuffer(nilBb) + notNilBb := AcquireBytesBuffer(10) + assert.NotEqual(t, nil, notNilBb) +} diff --git a/core/constants.go b/core/constants.go index 20ec7e76..a3108104 100644 --- a/core/constants.go +++ b/core/constants.go @@ -129,3 +129,7 @@ const ( EUnkonwnMsg = 1003 EConvertMsg = 1004 ) + +const ( + DefaultReferVersion = "1.0" +) diff --git a/core/map.go b/core/map.go index 4aa2bcf3..5d5502e8 100644 --- a/core/map.go +++ b/core/map.go @@ -26,6 +26,14 @@ func (m *StringMap) Store(key, value string) { m.mu.Unlock() } +func (m *StringMap) Reset() { + m.mu.Lock() + for k := range m.innerMap { + delete(m.innerMap, k) + } + m.mu.Unlock() +} + func (m *StringMap) Delete(key string) { m.mu.Lock() delete(m.innerMap, key) @@ -45,20 +53,13 @@ func (m *StringMap) LoadOrEmpty(key string) string { } // Range calls f sequentially for each key and value present in the map -// If f returns false, range stops the iteration +// If f returns false, range stops the iteration. +// +// Notice: do not delete elements in range function,because of Range loop the inner map directly. func (m *StringMap) Range(f func(k, v string) bool) { m.mu.RLock() - keys := make([]string, 0, len(m.innerMap)) - for k := range m.innerMap { - keys = append(keys, k) - } - m.mu.RUnlock() - - for _, k := range keys { - v, ok := m.Load(k) - if !ok { - continue - } + defer m.mu.RUnlock() + for k, v := range m.innerMap { if !f(k, v) { break } diff --git a/core/map_test.go b/core/map_test.go index ec1c95a5..aeddce9f 100644 --- a/core/map_test.go +++ b/core/map_test.go @@ -220,3 +220,16 @@ func BenchmarkCopyOnWriteMap_Load(b *testing.B) { } }) } + +func TestStringMap_Range(t *testing.T) { + a := NewStringMap(10) + a.Store("a", "a") + a.Store("b", "a") + a.Store("c", "a") + s := "" + a.Range(func(k, v string) bool { + s += v + return true + }) + assert.Equal(t, "aaa", s) +} diff --git a/core/motan.go b/core/motan.go index adfd9016..7b420cb5 100644 --- a/core/motan.go +++ b/core/motan.go @@ -14,7 +14,20 @@ import ( type taskHandler func() -var refreshTaskPool = make(chan taskHandler, 100) +var ( + refreshTaskPool = make(chan taskHandler, 100) + requestPool = sync.Pool{New: func() interface{} { + return &MotanRequest{ + RPCContext: &RPCContext{}, + Arguments: []interface{}{}, + } + }} + responsePool = sync.Pool{New: func() interface{} { + return &MotanResponse{ + RPCContext: &RPCContext{}, + } + }} +) func init() { go func() { @@ -28,10 +41,8 @@ func init() { } const ( - DefaultAttachmentSize = 16 - DefaultRPCContextMetaSize = 8 - - ProtocolLocal = "local" + DefaultAttachmentSize = 16 + ProtocolLocal = "local" ) var ( @@ -373,8 +384,6 @@ type RPCContext struct { AsyncCall bool Result *AsyncResult Reply interface{} - - Meta *StringMap // various time, it's owned by motan request context RequestSendTime time.Time RequestReceiveTime time.Time @@ -391,6 +400,27 @@ type RPCContext struct { RemoteAddr string // remote address } +func (c *RPCContext) Reset() { + // because there is a binding between RPCContext and request/response, + // some attributes such as RequestSendTime、RequestReceiveTime will be reset by request/response + // therefore, these attributes do not need to be reset here. + c.ExtFactory = nil + c.OriginalMessage = nil + c.Oneway = false + c.Proxy = false + c.GzipSize = 0 + c.BodySize = 0 + c.SerializeNum = 0 + c.Serialized = false + c.AsyncCall = false + c.Result = nil + c.Reply = nil + c.FinishHandlers = c.FinishHandlers[:0] + c.Tc = nil + c.IsMotanV1 = false + c.RemoteAddr = "" +} + func (c *RPCContext) AddFinishHandler(handler FinishHandler) { c.FinishHandlers = append(c.FinishHandlers, handler) } @@ -443,105 +473,125 @@ type MotanRequest struct { mu sync.Mutex } +func AcquireMotanRequest() *MotanRequest { + return requestPool.Get().(*MotanRequest) +} + +func ReleaseMotanRequest(req *MotanRequest) { + if req != nil { + req.Reset() + requestPool.Put(req) + } +} + +// Reset reset motan request +func (req *MotanRequest) Reset() { + req.Method = "" + req.RequestID = 0 + req.ServiceName = "" + req.MethodDesc = "" + req.RPCContext.Reset() + req.Attachment = nil + req.Arguments = req.Arguments[:0] +} + // GetAttachment GetAttachment -func (m *MotanRequest) GetAttachment(key string) string { - if m.Attachment == nil { +func (req *MotanRequest) GetAttachment(key string) string { + if req.Attachment == nil { return "" } - return m.Attachment.LoadOrEmpty(key) + return req.Attachment.LoadOrEmpty(key) } // SetAttachment : SetAttachment -func (m *MotanRequest) SetAttachment(key string, value string) { - m.GetAttachments().Store(key, value) +func (req *MotanRequest) SetAttachment(key string, value string) { + req.GetAttachments().Store(key, value) } // GetServiceName GetServiceName -func (m *MotanRequest) GetServiceName() string { - return m.ServiceName +func (req *MotanRequest) GetServiceName() string { + return req.ServiceName } // GetMethod GetMethod -func (m *MotanRequest) GetMethod() string { - return m.Method +func (req *MotanRequest) GetMethod() string { + return req.Method } // GetMethodDesc GetMethodDesc -func (m *MotanRequest) GetMethodDesc() string { - return m.MethodDesc +func (req *MotanRequest) GetMethodDesc() string { + return req.MethodDesc } -func (m *MotanRequest) GetArguments() []interface{} { - return m.Arguments +func (req *MotanRequest) GetArguments() []interface{} { + return req.Arguments } -func (m *MotanRequest) GetRequestID() uint64 { - return m.RequestID + +func (req *MotanRequest) GetRequestID() uint64 { + return req.RequestID } -func (m *MotanRequest) SetArguments(arguments []interface{}) { - m.Arguments = arguments +func (req *MotanRequest) SetArguments(arguments []interface{}) { + req.Arguments = arguments } -func (m *MotanRequest) GetAttachments() *StringMap { - attachment := (*StringMap)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&m.Attachment)))) +func (req *MotanRequest) GetAttachments() *StringMap { + attachment := (*StringMap)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&req.Attachment)))) if attachment != nil { return attachment } - m.mu.Lock() - defer m.mu.Unlock() - if m.Attachment == nil { + req.mu.Lock() + defer req.mu.Unlock() + if req.Attachment == nil { attachment = NewStringMap(DefaultAttachmentSize) - atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&m.Attachment)), unsafe.Pointer(attachment)) + atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&req.Attachment)), unsafe.Pointer(attachment)) } else { - attachment = m.Attachment + attachment = req.Attachment } return attachment } -func (m *MotanRequest) GetRPCContext(canCreate bool) *RPCContext { - if m.RPCContext == nil && canCreate { - m.RPCContext = &RPCContext{ - Meta: NewStringMap(DefaultRPCContextMetaSize), - } +func (req *MotanRequest) GetRPCContext(canCreate bool) *RPCContext { + if req.RPCContext == nil && canCreate { + req.RPCContext = &RPCContext{} } - return m.RPCContext + return req.RPCContext } -func (m *MotanRequest) Clone() interface{} { +func (req *MotanRequest) Clone() interface{} { newRequest := &MotanRequest{ - RequestID: m.RequestID, - ServiceName: m.ServiceName, - Method: m.Method, - MethodDesc: m.MethodDesc, - Arguments: m.Arguments, + RequestID: req.RequestID, + ServiceName: req.ServiceName, + Method: req.Method, + MethodDesc: req.MethodDesc, + Arguments: req.Arguments, } - if m.Attachment != nil { - newRequest.Attachment = m.Attachment.Copy() + if req.Attachment != nil { + newRequest.Attachment = req.Attachment.Copy() } - if m.RPCContext != nil { + if req.RPCContext != nil { newRequest.RPCContext = &RPCContext{ - ExtFactory: m.RPCContext.ExtFactory, - Oneway: m.RPCContext.Oneway, - Proxy: m.RPCContext.Proxy, - GzipSize: m.RPCContext.GzipSize, - SerializeNum: m.RPCContext.SerializeNum, - Serialized: m.RPCContext.Serialized, - AsyncCall: m.RPCContext.AsyncCall, - Result: m.RPCContext.Result, - Reply: m.RPCContext.Reply, - Meta: m.RPCContext.Meta, - RequestSendTime: m.RPCContext.RequestSendTime, - RequestReceiveTime: m.RPCContext.RequestReceiveTime, - ResponseSendTime: m.RPCContext.ResponseSendTime, - ResponseReceiveTime: m.RPCContext.ResponseReceiveTime, - FinishHandlers: m.RPCContext.FinishHandlers, - Tc: m.RPCContext.Tc, + ExtFactory: req.RPCContext.ExtFactory, + Oneway: req.RPCContext.Oneway, + Proxy: req.RPCContext.Proxy, + GzipSize: req.RPCContext.GzipSize, + SerializeNum: req.RPCContext.SerializeNum, + Serialized: req.RPCContext.Serialized, + AsyncCall: req.RPCContext.AsyncCall, + Result: req.RPCContext.Result, + Reply: req.RPCContext.Reply, + RequestSendTime: req.RPCContext.RequestSendTime, + RequestReceiveTime: req.RPCContext.RequestReceiveTime, + ResponseSendTime: req.RPCContext.ResponseSendTime, + ResponseReceiveTime: req.RPCContext.ResponseReceiveTime, + FinishHandlers: req.RPCContext.FinishHandlers, + Tc: req.RPCContext.Tc, } - if m.RPCContext.OriginalMessage != nil { - if oldMessage, ok := m.RPCContext.OriginalMessage.(Cloneable); ok { + if req.RPCContext.OriginalMessage != nil { + if oldMessage, ok := req.RPCContext.OriginalMessage.(Cloneable); ok { newRequest.RPCContext.OriginalMessage = oldMessage.Clone() } else { - newRequest.RPCContext.OriginalMessage = m.RPCContext.OriginalMessage + newRequest.RPCContext.OriginalMessage = req.RPCContext.OriginalMessage } } } @@ -550,14 +600,14 @@ func (m *MotanRequest) Clone() interface{} { // ProcessDeserializable : DeserializableValue to real params according toType // some serialization can deserialize without toType, so nil toType can be accepted in these serializations -func (m *MotanRequest) ProcessDeserializable(toTypes []interface{}) error { - if m.GetArguments() != nil && len(m.GetArguments()) == 1 { - if d, ok := m.GetArguments()[0].(*DeserializableValue); ok { +func (req *MotanRequest) ProcessDeserializable(toTypes []interface{}) error { + if req.GetArguments() != nil && len(req.GetArguments()) == 1 { + if d, ok := req.GetArguments()[0].(*DeserializableValue); ok { v, err := d.DeserializeMulti(toTypes) if err != nil { return err } - m.SetArguments(v) + req.SetArguments(v) } } return nil @@ -573,78 +623,99 @@ type MotanResponse struct { mu sync.Mutex } -func (m *MotanResponse) GetAttachment(key string) string { - if m.Attachment == nil { +func AcquireMotanResponse() *MotanResponse { + return responsePool.Get().(*MotanResponse) +} + +func ReleaseMotanResponse(m *MotanResponse) { + if m != nil { + m.Reset() + responsePool.Put(m) + } +} + +func (res *MotanResponse) Reset() { + res.RequestID = 0 + res.Value = nil + res.Exception = nil + res.ProcessTime = 0 + res.Attachment = nil + res.RPCContext.Reset() +} + +func (res *MotanResponse) GetAttachment(key string) string { + if res.Attachment == nil { return "" } - return m.Attachment.LoadOrEmpty(key) + return res.Attachment.LoadOrEmpty(key) } -func (m *MotanResponse) SetAttachment(key string, value string) { - m.GetAttachments().Store(key, value) +func (res *MotanResponse) SetAttachment(key string, value string) { + res.GetAttachments().Store(key, value) } -func (m *MotanResponse) GetValue() interface{} { - return m.Value +func (res *MotanResponse) GetValue() interface{} { + return res.Value } -func (m *MotanResponse) GetException() *Exception { - return m.Exception +func (res *MotanResponse) GetException() *Exception { + return res.Exception } -func (m *MotanResponse) GetRequestID() uint64 { - return m.RequestID +func (res *MotanResponse) GetRequestID() uint64 { + return res.RequestID } -func (m *MotanResponse) GetProcessTime() int64 { - return m.ProcessTime +func (res *MotanResponse) GetProcessTime() int64 { + return res.ProcessTime } -func (m *MotanResponse) GetAttachments() *StringMap { - attachment := (*StringMap)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&m.Attachment)))) +func (res *MotanResponse) GetAttachments() *StringMap { + attachment := (*StringMap)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&res.Attachment)))) if attachment != nil { return attachment } - m.mu.Lock() - defer m.mu.Unlock() - if m.Attachment == nil { + res.mu.Lock() + defer res.mu.Unlock() + if res.Attachment == nil { attachment = NewStringMap(DefaultAttachmentSize) - atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&m.Attachment)), unsafe.Pointer(attachment)) + atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&res.Attachment)), unsafe.Pointer(attachment)) } else { - attachment = m.Attachment + attachment = res.Attachment } return attachment } -func (m *MotanResponse) GetRPCContext(canCreate bool) *RPCContext { - if m.RPCContext == nil && canCreate { - m.RPCContext = &RPCContext{ - Meta: NewStringMap(DefaultRPCContextMetaSize), - } +func (res *MotanResponse) GetRPCContext(canCreate bool) *RPCContext { + if res.RPCContext == nil && canCreate { + res.RPCContext = &RPCContext{} } - return m.RPCContext + return res.RPCContext } -func (m *MotanResponse) SetProcessTime(time int64) { - m.ProcessTime = time +func (res *MotanResponse) SetProcessTime(time int64) { + res.ProcessTime = time } // ProcessDeserializable : same with MotanRequest -func (m *MotanResponse) ProcessDeserializable(toType interface{}) error { - if m.GetValue() != nil { - if d, ok := m.GetValue().(*DeserializableValue); ok { +func (res *MotanResponse) ProcessDeserializable(toType interface{}) error { + if res.GetValue() != nil { + if d, ok := res.GetValue().(*DeserializableValue); ok { v, err := d.Deserialize(toType) if err != nil { return err } - m.Value = v + res.Value = v } } return nil } func BuildExceptionResponse(requestid uint64, e *Exception) *MotanResponse { - return &MotanResponse{RequestID: requestid, Exception: e} + resp := AcquireMotanResponse() + resp.RequestID = requestid + resp.Exception = e + return resp } // extensions factory-func @@ -695,8 +766,8 @@ func (d *DefaultExtensionFactory) GetLB(url *URL) LoadBalance { } func (d *DefaultExtensionFactory) GetFilter(name string) Filter { - if newDefualt, ok := d.filterFactories[strings.TrimSpace(name)]; ok { - return newDefualt() + if newDefault, ok := d.filterFactories[strings.TrimSpace(name)]; ok { + return newDefault() } vlog.Errorf("filter name %s is not found in DefaultExtensionFactory!", name) return nil @@ -870,12 +941,15 @@ func (l *lastEndPointFilter) HasNext() bool { func (l *lastEndPointFilter) SetNext(nextFilter EndPointFilter) { vlog.Errorf("should not set next in lastEndPointFilter! filer:%s", nextFilter.GetName()) } + func (l *lastEndPointFilter) GetNext() EndPointFilter { return nil } + func (l *lastEndPointFilter) GetIndex() int { return 100 } + func (l *lastEndPointFilter) GetType() int32 { return EndPointFilterType } @@ -885,6 +959,7 @@ type lastClusterFilter struct{} func (l *lastClusterFilter) GetName() string { return "lastClusterFilter" } + func (l *lastClusterFilter) NewFilter(url *URL) Filter { return GetLastClusterFilter() } @@ -905,15 +980,19 @@ func (l *lastClusterFilter) Filter(haStrategy HaStrategy, loadBalance LoadBalanc func (l *lastClusterFilter) HasNext() bool { return false } + func (l *lastClusterFilter) SetNext(nextFilter ClusterFilter) { vlog.Errorf("should not set next in lastClusterFilter! filer:%s", nextFilter.GetName()) } + func (l *lastClusterFilter) GetNext() ClusterFilter { return nil } + func (l *lastClusterFilter) GetIndex() int { return 100 } + func (l *lastClusterFilter) GetType() int32 { return ClusterFilterType } @@ -931,12 +1010,15 @@ func (f *FilterEndPoint) Call(request Request) Response { } return f.Filter.Filter(f.Caller, request) } + func (f *FilterEndPoint) GetURL() *URL { return f.URL } + func (f *FilterEndPoint) SetURL(url *URL) { f.URL = url } + func (f *FilterEndPoint) GetName() string { return "FilterEndPoint" } @@ -1034,7 +1116,7 @@ func newRegistryGroupServiceCacheInfo(sr ServiceDiscoverableRegistry, group stri func (c *registryGroupServiceCacheInfo) getServices() ([]string, map[string]string) { if time.Now().Sub(c.lastUpdTime.Load().(time.Time)) >= registryGroupServiceInfoMaxCacheTime { select { - case refreshTaskPool <- taskHandler(func() { c.refreshServices() }): + case refreshTaskPool <- func() { c.refreshServices() }: default: vlog.Warningf("Task pool is full, refresh service of group [%s] delay", c.group) } diff --git a/core/test.go b/core/test.go index af3d857d..047683a8 100644 --- a/core/test.go +++ b/core/test.go @@ -18,6 +18,7 @@ type TestFilter struct { func (t *TestFilter) GetName() string { return "TestFilter" } + func (t *TestFilter) NewFilter(url *URL) Filter { //init with url in here return &TestFilter{URL: url} @@ -29,18 +30,23 @@ func (t *TestFilter) Filter(haStrategy HaStrategy, loadBalance LoadBalance, requ return t.GetNext().Filter(haStrategy, loadBalance, request) } + func (t *TestFilter) HasNext() bool { return t.next != nil } + func (t *TestFilter) SetNext(nextFilter ClusterFilter) { t.next = nextFilter } + func (t *TestFilter) GetNext() ClusterFilter { return t.next } + func (t *TestFilter) GetIndex() int { return t.Index } + func (t *TestFilter) GetType() int32 { return ClusterFilterType } @@ -54,6 +60,7 @@ type TestEndPointFilter struct { func (t *TestEndPointFilter) GetName() string { return "TestEndPointFilter" } + func (t *TestEndPointFilter) NewFilter(url *URL) Filter { //init with url in here return &TestEndPointFilter{URL: url} @@ -71,15 +78,19 @@ func (t *TestEndPointFilter) Filter(caller Caller, request Request) Response { func (t *TestEndPointFilter) HasNext() bool { return t.next != nil } + func (t *TestEndPointFilter) SetNext(nextFilter EndPointFilter) { t.next = nextFilter } + func (t *TestEndPointFilter) GetNext() EndPointFilter { return t.next } + func (t *TestEndPointFilter) GetIndex() int { return t.Index } + func (t *TestEndPointFilter) GetType() int32 { return EndPointFilterType } @@ -123,12 +134,15 @@ type TestEndPoint struct { func (t *TestEndPoint) GetURL() *URL { return t.URL } + func (t *TestEndPoint) SetURL(url *URL) { t.URL = url } + func (t *TestEndPoint) GetName() string { return "testEndPoint" } + func (t *TestEndPoint) Call(request Request) Response { fmt.Println("mock rpc request..") if t.ProcessTime != 0 { @@ -173,9 +187,11 @@ func (t *TestHaStrategy) GetName() string { func (t *TestHaStrategy) GetURL() *URL { return t.URL } + func (t *TestHaStrategy) SetURL(url *URL) { t.URL = url } + func (t *TestHaStrategy) Call(request Request, loadBalance LoadBalance) Response { fmt.Println("in testHaStrategy call") refer := loadBalance.Select(request) @@ -189,6 +205,7 @@ type TestLoadBalance struct { func (t *TestLoadBalance) OnRefresh(endpoints []EndPoint) { t.Endpoints = endpoints } + func (t *TestLoadBalance) Select(request Request) EndPoint { fmt.Println("in testLoadBalance select") endpoint := &TestEndPoint{} @@ -201,9 +218,11 @@ func (t *TestLoadBalance) Select(request Request) EndPoint { filterEndPoint.Filter = efilter1 return filterEndPoint } + func (t *TestLoadBalance) SelectArray(request Request) []EndPoint { return []EndPoint{&TestEndPoint{}} } + func (t *TestLoadBalance) SetWeight(weight string) { } @@ -217,42 +236,54 @@ type TestRegistry struct { func (t *TestRegistry) GetName() string { return "testRegistry" } + func (t *TestRegistry) Subscribe(url *URL, listener NotifyListener) { } + func (t *TestRegistry) Unsubscribe(url *URL, listener NotifyListener) { } + func (t *TestRegistry) Discover(url *URL) []*URL { return make([]*URL, 0) } + func (t *TestRegistry) Register(serverURL *URL) { } + func (t *TestRegistry) UnRegister(serverURL *URL) { } + func (t *TestRegistry) Available(serverURL *URL) { } + func (t *TestRegistry) Unavailable(serverURL *URL) { } + func (t *TestRegistry) GetRegisteredServices() []*URL { return make([]*URL, 0) } + func (t *TestRegistry) GetURL() *URL { if t.URL == nil { t.URL = &URL{} } return t.URL } + func (t *TestRegistry) SetURL(url *URL) { t.URL = url } + func (t *TestRegistry) InitRegistry() { } + func (t *TestRegistry) StartSnapshot(conf *SnapshotConf) { } diff --git a/core/url.go b/core/url.go index ec94faa6..11ac3be6 100644 --- a/core/url.go +++ b/core/url.go @@ -2,13 +2,13 @@ package core import ( "bytes" + "github.com/weibocom/motan-go/log" "sort" "strconv" "strings" + "sync" "sync/atomic" "time" - - "github.com/weibocom/motan-go/log" ) type URL struct { @@ -20,18 +20,26 @@ type URL struct { Parameters map[string]string `json:"parameters"` // cached info - address atomic.Value - identity atomic.Value + address atomic.Value + portStr atomic.Value + identity atomic.Value + hasMethodParamsCache atomic.Value // Whether it has method parameters + intParamCache sync.Map +} + +type int64Cache struct { + value int64 + isMiss bool // miss cache if true } var ( - defaultSerialize = "simple" + defaultSerialize = "simple" + defaultMethodParamsSubStr = ")." + defaultMissCache = &int64Cache{value: 0, isMiss: true} // Uniform miss cache ) -//TODO int param cache - // GetIdentity return the identity of url. identity info includes protocol, host, port, path, group -// the identity will cached, so must clear cached info after update above info by calling ClearCachedInfo() +// the identity will be cached, so must clear cached info after update above info by calling ClearCachedInfo() func (u *URL) GetIdentity() string { temp := u.identity.Load() if temp != nil && temp != "" { @@ -45,6 +53,33 @@ func (u *URL) GetIdentity() string { return idt } +// IsMatch is a tool function for comparing parameters: service, group, protocol and version +// with URL. When 'protocol' or 'version' is empty, it will be ignored +func (u *URL) IsMatch(service, group, protocol, version string) bool { + if u.Path != service { + return false + } + if group != "" && u.Group != group { + return false + } + // for motan v1 request, parameter protocol should be empty + if protocol != "" { + if u.Protocol == "motanV1Compatible" { + if protocol != "motan2" && protocol != "motan" { + return false + } + } else { + if u.Protocol != protocol { + return false + } + } + } + if version != "" && u.GetParam(VersionKey, "") != "" { + return version == u.GetParam(VersionKey, "") + } + return true +} + func (u *URL) GetIdentityWithRegistry() string { id := u.GetIdentity() registryId := u.GetParam(RegistryKey, "") @@ -54,14 +89,20 @@ func (u *URL) GetIdentityWithRegistry() string { func (u *URL) ClearCachedInfo() { u.address.Store("") u.identity.Store("") + u.portStr.Store("") + u.hasMethodParamsCache.Store("") + u.intParamCache.Range(func(key interface{}, value interface{}) bool { + u.intParamCache.Delete(key) + return true + }) } -func (u *URL) GetPositiveIntValue(key string, defaultvalue int64) int64 { - intvalue := u.GetIntValue(key, defaultvalue) - if intvalue < 1 { - return defaultvalue +func (u *URL) GetPositiveIntValue(key string, defaultValue int64) int64 { + intValue := u.GetIntValue(key, defaultValue) + if intValue < 1 { + return defaultValue } - return intvalue + return intValue } func (u *URL) GetBoolValue(key string, defaultValue bool) bool { @@ -83,39 +124,70 @@ func (u *URL) GetIntValue(key string, defaultValue int64) int64 { } func (u *URL) GetInt(key string) (int64, bool) { + if cache, ok := u.intParamCache.Load(key); ok { + if c, ok := cache.(*int64Cache); ok { // from cache + if c.isMiss { + return 0, false + } + return c.value, true + } + } + if v, ok := u.Parameters[key]; ok { intValue, err := strconv.ParseInt(v, 10, 64) if err == nil { + u.intParamCache.Store(key, &int64Cache{value: intValue, isMiss: false}) return intValue, true } } + u.intParamCache.Store(key, defaultMissCache) // set miss cache return 0, false } -func (u *URL) GetStringParamsWithDefault(key string, defaultvalue string) string { +func (u *URL) GetStringParamsWithDefault(key string, defaultValue string) string { var ret string if u.Parameters != nil { ret = u.Parameters[key] } if ret == "" { - ret = defaultvalue + ret = defaultValue } return ret } func (u *URL) GetMethodIntValue(method string, methodDesc string, key string, defaultValue int64) int64 { - mkey := method + "(" + methodDesc + ")." + key - result, b := u.GetInt(mkey) - if b { - return result + if u.hasMethodParams() { + mk := method + "(" + methodDesc + ")." + key + result, b := u.GetInt(mk) + if b { + return result + } } - result, b = u.GetInt(key) + result, b := u.GetInt(key) if b { return result } return defaultValue } +func (u *URL) hasMethodParams() bool { + v := u.hasMethodParamsCache.Load() + if v == nil || v == "" { // Check if method parameters exist + if u.Parameters != nil { + for k := range u.Parameters { + if strings.Contains(k, defaultMethodParamsSubStr) { + v = "t" + u.hasMethodParamsCache.Store("t") + return true + } + } + } + v = "f" + u.hasMethodParamsCache.Store("f") + } + return v == "t" +} + func (u *URL) GetMethodPositiveIntValue(method string, methodDesc string, key string, defaultValue int64) int64 { result := u.GetMethodIntValue(method, methodDesc, key, defaultValue) if result > 0 { @@ -146,6 +218,10 @@ func (u *URL) GetTimeDuration(key string, unit time.Duration, defaultDuration ti } func (u *URL) PutParam(key string, value string) { + u.intParamCache.Delete(key) // remove cache + if strings.Contains(key, defaultMethodParamsSubStr) { // Check if method parameter + u.hasMethodParamsCache.Store("t") + } if u.Parameters == nil { u.Parameters = make(map[string]string) } @@ -181,20 +257,20 @@ func (u *URL) ToExtInfo() string { } -func FromExtInfo(extinfo string) *URL { - defer func() { // if extinfo format not correct, just return nil URL +func FromExtInfo(extInfo string) *URL { + defer func() { // if extInfo format not correct, just return nil URL if err := recover(); err != nil { - vlog.Warningf("from ext to url fail. extinfo:%s, err:%v", extinfo, err) + vlog.Warningf("from ext to url fail. extInfo:%s, err:%v", extInfo, err) } }() - arr := strings.Split(extinfo, "?") - nodeinfos := strings.Split(arr[0], "://") - protocol := nodeinfos[0] - nodeinfos = strings.Split(nodeinfos[1], "/") - path := nodeinfos[1] - nodeinfos = strings.Split(nodeinfos[0], ":") - host := nodeinfos[0] - port, _ := strconv.ParseInt(nodeinfos[1], 10, 64) + arr := strings.Split(extInfo, "?") + nodeInfos := strings.Split(arr[0], "://") + protocol := nodeInfos[0] + nodeInfos = strings.Split(nodeInfos[1], "/") + path := nodeInfos[1] + nodeInfos = strings.Split(nodeInfos[0], ":") + host := nodeInfos[0] + port, _ := strconv.ParseInt(nodeInfos[1], 10, 64) paramsMap := make(map[string]string) params := strings.Split(arr[1], "&") @@ -212,7 +288,13 @@ func FromExtInfo(extinfo string) *URL { } func (u *URL) GetPortStr() string { - return strconv.FormatInt(int64(u.Port), 10) + temp := u.portStr.Load() + if temp != nil && temp != "" { + return temp.(string) + } + p := strconv.FormatInt(int64(u.Port), 10) + u.portStr.Store(p) + return p } func (u *URL) GetAddressStr() string { @@ -258,7 +340,8 @@ func (u *URL) CanServe(other *URL) bool { vlog.Errorf("can not serve serialization, err : s1:%s, s2:%s", u.Parameters[SerializationKey], other.Parameters[SerializationKey]) return false } - if !IsSame(u.Parameters, other.Parameters, VersionKey, "0.1") { + // compatible with old version: 0.1 + if !(IsSame(u.Parameters, other.Parameters, VersionKey, "0.1") || IsSame(u.Parameters, other.Parameters, VersionKey, DefaultReferVersion)) { vlog.Errorf("can not serve version, err : v1:%s, v2:%s", u.Parameters[VersionKey], other.Parameters[VersionKey]) return false } @@ -328,9 +411,11 @@ type filterSlice []Filter func (f filterSlice) Len() int { return len(f) } + func (f filterSlice) Swap(i, j int) { f[i], f[j] = f[j], f[i] } + func (f filterSlice) Less(i, j int) bool { // desc return f[i].GetIndex() > f[j].GetIndex() diff --git a/core/url_test.go b/core/url_test.go index 59baa269..73ca3a73 100644 --- a/core/url_test.go +++ b/core/url_test.go @@ -2,16 +2,17 @@ package core import ( "fmt" + "github.com/stretchr/testify/assert" "strconv" "strings" "testing" ) -func TestFromExtinfo(t *testing.T) { - extinfo := "triggerMemcache://10.73.32.175:22222/com.weibo.trigger.common.bean.TriggerMemcacheClient?nodeType=service&version=1.0&group=status2-core" - url := FromExtInfo(extinfo) +func TestFromExtInfo(t *testing.T) { + extInfo := "triggerMemcache://10.73.32.175:22222/com.weibo.trigger.common.bean.TriggerMemcacheClient?nodeType=service&version=1.0&group=status2-core" + url := FromExtInfo(extInfo) if url == nil { - t.Fatal("parse form extinfo fail") + t.Fatal("parse form extInfo fail") } fmt.Printf("url:%+v", url) if url.Host != "10.73.32.175" || url.Port != 22222 || url.Protocol != "triggerMemcache" || @@ -27,14 +28,14 @@ func TestFromExtinfo(t *testing.T) { !strings.Contains(ext2, "group=status2-core") || !strings.Contains(ext2, "nodeType=service") || !strings.Contains(ext2, "version=1.0") { - t.Fatalf("convert url to extinfo not correct. ext2: %s", ext2) + t.Fatalf("convert url to extInfo not correct. ext2: %s", ext2) } fmt.Println("ext2:", ext2) //invalid invalidInfo := "motan://123.23.33.32" url = FromExtInfo(invalidInfo) if url != nil { - t.Fatal("url should be nil when parse invalid extinfo") + t.Fatal("url should be nil when parse invalid extInfo") } } @@ -42,35 +43,36 @@ func TestGetInt(t *testing.T) { url := &URL{} params := make(map[string]string) url.Parameters = params - key := "keyt" + key := "keyT" method := "method1" methodDesc := "string,string" params[key] = "12" v, _ := url.GetInt(key) - intequals(12, v, t) + intEquals(12, v, t) - params[key] = "-20" + url.PutParam(key, "-20") // use PutParam set value will update the cache v, _ = url.GetInt(key) - intequals(-20, v, t) + intEquals(-20, v, t) v = url.GetPositiveIntValue(key, 8) - intequals(8, v, t) + intEquals(8, v, t) delete(params, key) + url.ClearCachedInfo() // clear cache v = url.GetMethodIntValue(method, methodDesc, key, 6) - intequals(6, v, t) + intEquals(6, v, t) - url.Parameters[method+"("+methodDesc+")."+key] = "-17" + url.PutParam(method+"("+methodDesc+")."+key, "-17") v = url.GetMethodIntValue(method, methodDesc, key, 6) - intequals(-17, v, t) + intEquals(-17, v, t) v = url.GetMethodPositiveIntValue(method, methodDesc, key, 9) - intequals(9, v, t) + intEquals(9, v, t) } -func intequals(expect int64, realvalue int64, t *testing.T) { - if realvalue != expect { - t.Fatalf("getint test fail, expect :%d, real :%d", expect, realvalue) +func intEquals(expect int64, realValue int64, t *testing.T) { + if realValue != expect { + t.Fatalf("getint test fail, expect :%d, real :%d", expect, realValue) } } @@ -99,7 +101,6 @@ func TestCopyAndMerge(t *testing.T) { if url.Parameters["key1"] != "xxx" { t.Fatalf("url merge not correct. expect v :%s, real v: %s", "xxx", url.Parameters["key1"]) } - } func TestCanServer(t *testing.T) { @@ -140,7 +141,7 @@ func TestCanServer(t *testing.T) { url1.Protocol = "" url2.Protocol = "" url1.Path = "test/path" - url2.Path = "xxxx" + url2.Path = "whatever" if url1.CanServe(url2) { t.Fatalf("url CanServe testFail url1: %+v, url2: %+v\n", url1, url2) } @@ -156,3 +157,69 @@ func TestGetPositiveIntValue(t *testing.T) { t.Errorf("get positive int fail. v:%d", v) } } + +func TestIntParamCache(t *testing.T) { + url := &URL{} + // test normal + url.PutParam(SessionTimeOutKey, "20000000") + _, ok := url.intParamCache.Load(SessionTimeOutKey) // cache will remove after new value set + assert.False(t, ok) + checkIntCache(t, url, SessionTimeOutKey, 20000000) + + url.PutParam(SessionTimeOutKey, "15") + _, ok = url.intParamCache.Load(SessionTimeOutKey) // cache will remove after new value set + assert.False(t, ok) + checkIntCache(t, url, SessionTimeOutKey, 15) + + // clear cache + url.ClearCachedInfo() + _, ok = url.intParamCache.Load(SessionTimeOutKey) // cache will remove after new value set + assert.False(t, ok) + checkIntCache(t, url, SessionTimeOutKey, 15) + + // test miss cache + _, ok = url.GetInt("notExist") + assert.False(t, ok) + nv, ok := url.intParamCache.Load("notExist") + assert.True(t, ok) + assert.NotNil(t, nv) + if ic, ok := nv.(*int64Cache); ok { + assert.True(t, ic.isMiss) + assert.Equal(t, defaultMissCache, ic) + } + assert.True(t, ok) + + // test hasMethod cache + v := url.GetMethodIntValue("method", "desc", "testKey", 10) + assert.Equal(t, int64(10), v) // default value + assert.False(t, url.hasMethodParams()) + url.PutParam("method(desc).testKey", "100") + assert.True(t, url.hasMethodParams()) + v = url.GetMethodIntValue("method", "desc", "testKey", 10) + assert.Equal(t, int64(100), v) + url.ClearCachedInfo() + // test init with method params + assert.True(t, url.hasMethodParams()) + // test init without method params + delete(url.Parameters, "method(desc).testKey") + url.ClearCachedInfo() + assert.False(t, url.hasMethodParams()) + + // test GetPortStr + url.Port = 8080 + assert.Equal(t, "", url.portStr.Load()) + assert.Equal(t, "8080", url.GetPortStr()) + assert.Equal(t, "8080", url.portStr.Load().(string)) +} + +func checkIntCache(t *testing.T, url *URL, k string, v int64) { + dv := url.GetIntValue(k, 20) + assert.Equal(t, v, dv) + cv, ok := url.intParamCache.Load(k) + assert.True(t, ok) + assert.NotNil(t, cv) + if i, ok := cv.(*int64Cache); ok { + assert.Equal(t, v, i.value) + } + assert.True(t, ok) +} diff --git a/dynamicConfig.go b/dynamicConfig.go index 98ca0538..afc10cc5 100644 --- a/dynamicConfig.go +++ b/dynamicConfig.go @@ -32,7 +32,7 @@ type DynamicConfigurer struct { agent *Agent } -type registrySnapInfoStorage struct { +type RegistrySnapInfoStorage struct { RegisterNodes []*core.URL `json:"register_nodes"` SubscribeNodes []*core.URL `json:"subscribe_nodes"` } @@ -60,7 +60,7 @@ func (c *DynamicConfigurer) doRecover() error { vlog.Warningln("Read configuration snapshot file error: " + err.Error()) return err } - registerSnapInfo := new(registrySnapInfoStorage) + registerSnapInfo := new(RegistrySnapInfoStorage) err = json.Unmarshal(bytes, registerSnapInfo) if err != nil { vlog.Errorln("Parse snapshot string error: " + err.Error()) @@ -162,8 +162,8 @@ func (c *DynamicConfigurer) saveSnapshot() { } } -func (c *DynamicConfigurer) getRegistryInfo() *registrySnapInfoStorage { - registrySnapInfo := registrySnapInfoStorage{} +func (c *DynamicConfigurer) getRegistryInfo() *RegistrySnapInfoStorage { + registrySnapInfo := RegistrySnapInfoStorage{} c.regLock.Lock() defer c.regLock.Unlock() diff --git a/endpoint/motanCommonEndpoint.go b/endpoint/motanCommonEndpoint.go index 220587a3..91bee948 100644 --- a/endpoint/motanCommonEndpoint.go +++ b/endpoint/motanCommonEndpoint.go @@ -3,6 +3,7 @@ package endpoint import ( "bufio" "errors" + "github.com/panjf2000/ants/v2" motan "github.com/weibocom/motan-go/core" vlog "github.com/weibocom/motan-go/log" mpro "github.com/weibocom/motan-go/protocol" @@ -14,6 +15,15 @@ import ( "time" ) +var ( + streamPool = sync.Pool{New: func() interface{} { + return &Stream{ + recvNotifyCh: make(chan struct{}, 1), + } + }} + handleMsgPool, _ = ants.NewPool(10000) +) + // MotanCommonEndpoint supports motan v1, v2 protocols type MotanCommonEndpoint struct { url *motan.URL @@ -33,6 +43,7 @@ type MotanCommonEndpoint struct { lazyInit bool maxContentLength int heartbeatVersion int + gzipSize int keepaliveRunning bool serialization motan.Serialization @@ -67,6 +78,7 @@ func (m *MotanCommonEndpoint) Initialize() { asyncInitConnection := m.url.GetBoolValue(motan.AsyncInitConnection, GetDefaultMotanEPAsynInit()) m.heartbeatVersion = -1 m.DefaultVersion = mpro.Version2 + m.gzipSize = int(m.url.GetIntValue(motan.GzipSizeKey, 0)) factory := func() (net.Conn, error) { address := m.url.GetAddressStr() if strings.HasPrefix(address, motan.UnixSockProtocolFlag) { @@ -107,9 +119,12 @@ func (m *MotanCommonEndpoint) GetRequestTimeout(request motan.Request) time.Dura if maxTimeout == 0 { maxTimeout = timeout * 2 } - reqTimeout, _ := strconv.ParseInt(request.GetAttachment(mpro.MTimeout), 10, 64) - if reqTimeout >= minTimeout && reqTimeout <= maxTimeout { - timeout = reqTimeout + rt := request.GetAttachment(mpro.MTimeout) + if rt != "" { + reqTimeout, _ := strconv.ParseInt(rt, 10, 64) + if reqTimeout >= minTimeout && reqTimeout <= maxTimeout { + timeout = reqTimeout + } } return time.Duration(timeout) * time.Millisecond } @@ -117,7 +132,7 @@ func (m *MotanCommonEndpoint) GetRequestTimeout(request motan.Request) time.Dura func (m *MotanCommonEndpoint) Call(request motan.Request) motan.Response { rc := request.GetRPCContext(true) rc.Proxy = m.proxy - rc.GzipSize = int(m.url.GetIntValue(motan.GzipSizeKey, 0)) + rc.GzipSize = m.gzipSize if m.channels == nil { vlog.Errorf("motanEndpoint %s error: channels is null", m.url.GetAddressStr()) @@ -285,16 +300,11 @@ func (m *MotanCommonEndpoint) keepalive() { } func (m *MotanCommonEndpoint) defaultErrMotanResponse(request motan.Request, errMsg string) motan.Response { - response := &motan.MotanResponse{ - RequestID: request.GetRequestID(), - Attachment: motan.NewStringMap(motan.DefaultAttachmentSize), - Exception: &motan.Exception{ - ErrCode: 400, - ErrMsg: errMsg, - ErrType: motan.ServiceException, - }, - } - return response + return motan.BuildExceptionResponse(request.GetRequestID(), &motan.Exception{ + ErrCode: 400, + ErrMsg: errMsg, + ErrType: motan.ServiceException, + }) } func (m *MotanCommonEndpoint) GetName() string { @@ -348,14 +358,31 @@ type Stream struct { recvNotifyCh chan struct{} deadline time.Time // for timeout rc *motan.RPCContext - isClose atomic.Value // bool - isHeartbeat bool // for heartbeat - heartbeatVersion int // for heartbeat + isHeartbeat bool // for heartbeat + heartbeatVersion int // for heartbeat + timer *time.Timer + canRelease atomic.Value // state indicates whether the stream needs to be recycled by the channel +} + +func (s *Stream) Reset() { + // try consume + select { + case <-s.recvNotifyCh: + default: + } + s.channel = nil + s.req = nil + s.res = nil + s.rc = nil } func (s *Stream) Send() (err error) { - timer := time.NewTimer(s.deadline.Sub(time.Now())) - defer timer.Stop() + if s.timer == nil { + s.timer = time.NewTimer(s.deadline.Sub(time.Now())) + } else { + s.timer.Reset(s.deadline.Sub(time.Now())) + } + defer s.timer.Stop() var bytes []byte var msg *mpro.Message @@ -386,13 +413,16 @@ func (s *Stream) Send() (err error) { } } + ready := sendReady{} if msg != nil { // encode v2 message - bytes = msg.Encode().Bytes() + msg.Encode0() + ready.message = msg + } else { + ready.v1Message = bytes } if s.rc != nil && s.rc.Tc != nil { s.rc.Tc.PutReqSpan(&motan.Span{Name: motan.Encode, Addr: s.channel.address, Time: time.Now()}) } - ready := sendReady{data: bytes} select { case s.channel.sendCh <- ready: if s.rc != nil { @@ -401,9 +431,13 @@ func (s *Stream) Send() (err error) { if s.rc.Tc != nil { s.rc.Tc.PutReqSpan(&motan.Span{Name: motan.Send, Addr: s.channel.address, Time: sendTime}) } + if s.rc.AsyncCall { + // only return send success, it can release in Stream.notify + s.canRelease.Store(true) + } } return nil - case <-timer.C: + case <-s.timer.C: return ErrSendRequestTimeout case <-s.channel.shutdownCh: return ErrChannelShutdown @@ -413,10 +447,18 @@ func (s *Stream) Send() (err error) { // Recv sync recv func (s *Stream) Recv() (motan.Response, error) { defer func() { - s.Close() + // only timeout or shutdown before channel.handleMsg, call RemoveFromChannel will be true, + // which means stream can release + if s.RemoveFromChannel() { + s.canRelease.Store(true) + } }() - timer := time.NewTimer(s.deadline.Sub(time.Now())) - defer timer.Stop() + if s.timer == nil { + s.timer = time.NewTimer(s.deadline.Sub(time.Now())) + } else { + s.timer.Reset(s.deadline.Sub(time.Now())) + } + defer s.timer.Stop() select { case <-s.recvNotifyCh: msg := s.res @@ -424,17 +466,18 @@ func (s *Stream) Recv() (motan.Response, error) { return nil, errors.New("recv err: recvMsg is nil") } return msg, nil - case <-timer.C: + case <-s.timer.C: + // stream may be referenced by V2Stream.notify, can`t release + s.canRelease.Store(false) return nil, ErrRecvRequestTimeout case <-s.channel.shutdownCh: + // stream may be referenced by V2Stream.notify, can`t release + s.canRelease.Store(false) return nil, ErrChannelShutdown } } func (s *Stream) notify(msg interface{}, t time.Time) { - defer func() { - s.Close() - }() decodeTime := time.Now() var res motan.Response var v2Msg *mpro.Message @@ -481,6 +524,10 @@ func (s *Stream) notify(msg interface{}, t time.Time) { s.rc.Tc.PutResSpan(&motan.Span{Name: motan.Convert, Addr: s.channel.address, Time: time.Now()}) } if s.rc.AsyncCall { + defer func() { + s.RemoveFromChannel() + releaseStream(s) + }() result := s.rc.Result if err != nil { result.Error = err @@ -503,70 +550,78 @@ func (s *Stream) SetDeadline(deadline time.Duration) { s.deadline = time.Now().Add(deadline) } -func (c *Channel) NewStream(req motan.Request, rc *motan.RPCContext) (*Stream, error) { +func (c *Channel) newStream(req motan.Request, rc *motan.RPCContext, deadline time.Duration) (*Stream, error) { if c.IsClosed() { return nil, ErrChannelShutdown } - s := &Stream{ - streamId: GenerateRequestID(), - channel: c, - req: req, - recvNotifyCh: make(chan struct{}, 1), - deadline: time.Now().Add(defaultRequestTimeout), // default deadline - rc: rc, + s := acquireStream() + s.streamId = GenerateRequestID() + s.channel = c + s.isHeartbeat = false + s.req = req + s.deadline = time.Now().Add(deadline) + s.rc = rc + if rc != nil && rc.AsyncCall { // release by Stream.Notify + s.canRelease.Store(false) + } else { // release by Channel self + s.canRelease.Store(true) } - s.isClose.Store(false) c.streamLock.Lock() c.streams[s.streamId] = s c.streamLock.Unlock() return s, nil } -func (c *Channel) NewHeartbeatStream(heartbeatVersion int) (*Stream, error) { +func (c *Channel) newHeartbeatStream(heartbeatVersion int) (*Stream, error) { if c.IsClosed() { return nil, ErrChannelShutdown } - s := &Stream{ - streamId: GenerateRequestID(), - channel: c, - isHeartbeat: true, - heartbeatVersion: heartbeatVersion, - recvNotifyCh: make(chan struct{}, 1), - deadline: time.Now().Add(defaultRequestTimeout), - } - s.isClose.Store(false) + s := acquireStream() + s.streamId = GenerateRequestID() + s.channel = c + s.isHeartbeat = true + s.heartbeatVersion = heartbeatVersion + s.deadline = time.Now().Add(defaultRequestTimeout) + s.canRelease.Store(true) c.heartbeatLock.Lock() c.heartbeats[s.streamId] = s c.heartbeatLock.Unlock() return s, nil } -func (s *Stream) Close() { - if !s.isClose.Load().(bool) { - if s.isHeartbeat { - s.channel.heartbeatLock.Lock() +func (s *Stream) RemoveFromChannel() bool { + var exist bool + if s.isHeartbeat { + s.channel.heartbeatLock.Lock() + if _, exist = s.channel.heartbeats[s.streamId]; exist { delete(s.channel.heartbeats, s.streamId) - s.channel.heartbeatLock.Unlock() - } else { - s.channel.streamLock.Lock() + } + s.channel.heartbeatLock.Unlock() + } else { + s.channel.streamLock.Lock() + if _, exist = s.channel.streams[s.streamId]; exist { delete(s.channel.streams, s.streamId) - s.channel.streamLock.Unlock() } - s.isClose.Store(true) + s.channel.streamLock.Unlock() } + return exist } // Call send request to the server. // // about return: exception in response will record error count, err will not. func (c *Channel) Call(req motan.Request, deadline time.Duration, rc *motan.RPCContext) (motan.Response, error) { - stream, err := c.NewStream(req, rc) + stream, err := c.newStream(req, rc, deadline) if err != nil { return nil, err } - stream.SetDeadline(deadline) - err = stream.Send() - if err != nil { + defer func() { + if rc == nil || !rc.AsyncCall { + releaseStream(stream) + } + }() + + if err = stream.Send(); err != nil { return nil, err } if rc != nil && rc.AsyncCall { @@ -576,12 +631,13 @@ func (c *Channel) Call(req motan.Request, deadline time.Duration, rc *motan.RPCC } func (c *Channel) HeartBeat(heartbeatVersion int) (motan.Response, error) { - stream, err := c.NewHeartbeatStream(heartbeatVersion) + stream, err := c.newHeartbeatStream(heartbeatVersion) if err != nil { return nil, err } - err = stream.Send() - if err != nil { + defer releaseStream(stream) + + if err = stream.Send(); err != nil { return nil, err } return stream.Recv() @@ -601,6 +657,7 @@ func (c *Channel) recv() { } func (c *Channel) recvLoop() error { + decodeBuf := make([]byte, mpro.DefaultBufferSize) for { v, err := mpro.CheckMotanVersion(c.bufRead) if err != nil { @@ -611,7 +668,7 @@ func (c *Channel) recvLoop() error { if v == mpro.Version1 { msg, t, err = mpro.ReadV1Message(c.bufRead, c.config.MaxContentLength) } else if v == mpro.Version2 { - msg, t, err = mpro.DecodeWithTime(c.bufRead, c.config.MaxContentLength) + msg, t, err = mpro.DecodeWithTime(c.bufRead, &decodeBuf, c.config.MaxContentLength) } else { vlog.Warningf("unsupported motan version! version:%d con:%s.", v, c.conn.RemoteAddr().String()) err = mpro.ErrVersion @@ -619,7 +676,10 @@ func (c *Channel) recvLoop() error { if err != nil { return err } - go c.handleMsg(msg, t) + + handleMsgPool.Submit(func() { + c.handleMsg(msg, t) + }) } } @@ -647,10 +707,12 @@ func (c *Channel) handleMsg(msg interface{}, t time.Time) { if isHeartbeat { c.heartbeatLock.Lock() stream = c.heartbeats[rid] + delete(c.heartbeats, rid) c.heartbeatLock.Unlock() } else { c.streamLock.Lock() stream = c.streams[rid] + delete(c.streams, rid) c.streamLock.Unlock() } if stream == nil { @@ -667,18 +729,24 @@ func (c *Channel) send() { for { select { case ready := <-c.sendCh: - if ready.data != nil { - c.conn.SetWriteDeadline(time.Now().Add(motan.DefaultWriteTimeout)) - sent := 0 - for sent < len(ready.data) { - n, err := c.conn.Write(ready.data[sent:]) - if err != nil { - vlog.Errorf("Failed to write channel. ep: %s, err: %s", c.address, err.Error()) - c.closeOnErr(err) - return - } - sent += n - } + // can`t reuse net.Buffers + // len and cap will be 0 after writev, 'out of range panic' will happen when reuse + var sendBuf net.Buffers + if ready.message != nil { // motan2 + sendBuf = ready.message.GetEncodedBytes() + } else { + sendBuf = [][]byte{ready.v1Message} + } + c.conn.SetWriteDeadline(time.Now().Add(motan.DefaultWriteTimeout)) + _, err := sendBuf.WriteTo(c.conn) + if ready.message != nil { + // message canRelease condition before reset + ready.message.SetCanRelease() + } + if err != nil { + vlog.Errorf("Failed to write channel. ep: %s, err: %s", c.address, err.Error()) + c.closeOnErr(err) + return } case <-c.shutdownCh: return @@ -845,3 +913,16 @@ func buildChannel(conn net.Conn, config *ChannelConfig, serialization motan.Seri return channel } + +func acquireStream() *Stream { + return streamPool.Get().(*Stream) +} + +func releaseStream(stream *Stream) { + if stream != nil { + if v, ok := stream.canRelease.Load().(bool); ok && v { + stream.Reset() + streamPool.Put(stream) + } + } +} diff --git a/endpoint/motanCommonEndpoint_test.go b/endpoint/motanCommonEndpoint_test.go index 3cd2715f..75c460c5 100644 --- a/endpoint/motanCommonEndpoint_test.go +++ b/endpoint/motanCommonEndpoint_test.go @@ -48,6 +48,8 @@ func TestV1RecordErrEmptyThreshold(t *testing.T) { ep.Call(request) assert.True(t, ep.IsAvailable()) } + + assertChanelStreamEmpty(ep, t) ep.Destroy() } @@ -79,6 +81,8 @@ func TestV1RecordErrWithErrThreshold(t *testing.T) { _ = conn.(*net.TCPConn).SetNoDelay(true) ep.channels.channels <- buildChannel(conn, ep.channels.config, ep.channels.serialization) time.Sleep(time.Second * 2) + + assertChanelStreamEmpty(ep, t) //assert.True(t, ep.IsAvailable()) ep.Destroy() } @@ -102,6 +106,8 @@ func TestMotanCommonEndpoint_SuccessCall(t *testing.T) { s, ok := v.(string) assert.True(t, ok) assert.Equal(t, s, "hello") + + assertChanelStreamEmpty(ep, t) } func TestMotanCommonEndpoint_AsyncCall(t *testing.T) { @@ -123,6 +129,8 @@ func TestMotanCommonEndpoint_AsyncCall(t *testing.T) { resp := <-request.GetRPCContext(false).Result.Done assert.Nil(t, resp.Error) assert.Equal(t, resStr, "hello") + + assertChanelStreamEmpty(ep, t) } func TestMotanCommonEndpoint_ErrorCall(t *testing.T) { @@ -147,6 +155,8 @@ func TestMotanCommonEndpoint_ErrorCall(t *testing.T) { ep.Call(request) time.Sleep(1 * time.Millisecond) assert.Equal(t, beforeNGoroutine, runtime.NumGoroutine()) + + assertChanelStreamEmpty(ep, t) ep.Destroy() } @@ -173,6 +183,8 @@ func TestMotanCommonEndpoint_RequestTimeout(t *testing.T) { ep.Call(request) time.Sleep(1 * time.Millisecond) assert.Equal(t, beforeNGoroutine, runtime.NumGoroutine()) + + assertChanelStreamEmpty(ep, t) ep.Destroy() } @@ -215,3 +227,100 @@ func TestV1AsyncInit(t *testing.T) { ep.Initialize() time.Sleep(time.Second * 5) } + +// TestMotanCommonEndpoint_AsyncCallNoResponse verify V2Channel streams memory leak when server not reply response +// TODO:: bugs to be fixed +func TestMotanCommonEndpoint_AsyncCallNoResponse(t *testing.T) { + url := &motan.URL{Port: 8989, Protocol: "motanV1Compatible"} + url.PutParam(motan.TimeOutKey, "2000") + url.PutParam(motan.ErrorCountThresholdKey, "1") + url.PutParam(motan.ClientConnectionKey, "1") + url.PutParam(motan.AsyncInitConnection, "false") + ep := &MotanCommonEndpoint{} + ep.SetURL(url) + ep.SetSerialization(&serialize.SimpleSerialization{}) + ep.Initialize() + assert.Equal(t, 1, ep.clientConnection) + var resStr string + request := &motan.MotanRequest{ServiceName: "test", Method: "test", RPCContext: &motan.RPCContext{AsyncCall: true, Result: &motan.AsyncResult{Reply: &resStr, Done: make(chan *motan.AsyncResult, 5)}}} + request.Attachment = motan.NewStringMap(0) + // server not reply + request.SetAttachment("no_response", "true") + + res := ep.Call(request) + assert.Nil(t, res.GetException()) + timeoutTimer := time.NewTimer(time.Second * 3) + defer timeoutTimer.Stop() + select { + case <-request.GetRPCContext(false).Result.Done: + t.Errorf("unexpect condition, recv response singnal") + case <-timeoutTimer.C: + t.Logf("expect condition, not recv response singnal") + } + + // Channel.streams can`t release stream + c := <-ep.channels.getChannels() + // it will be zero if server not reply response, bug to be fixed + assert.Equal(t, 1, len(c.streams)) +} + +func TestStreamPool(t *testing.T) { + var oldStream *Stream + // consume stream poll until call New func + for { + oldStream = acquireStream() + if v, ok := oldStream.canRelease.Load().(bool); !ok || !v { + break + } + } + // test new Stream + assert.NotNil(t, oldStream) + assert.NotNil(t, oldStream.recvNotifyCh) + oldStream.streamId = GenerateRequestID() + // verify reset + oldStream.recvNotifyCh <- struct{}{} + + // test canRelease + // oldStream.canRelease is not ture,release fail + // test reset recvNotifyCh + assert.Equal(t, 1, len(oldStream.recvNotifyCh)) + releaseStream(oldStream) + assert.Equal(t, 1, len(oldStream.recvNotifyCh)) + // release success + oldStream.canRelease.Store(true) + releaseStream(oldStream) + assert.Equal(t, 0, len(oldStream.recvNotifyCh)) + + // test put nil + var nilStream *Stream + // can not put nil to pool + releaseStream(nilStream) + newStream3 := acquireStream() + assert.NotEqual(t, nil, newStream3) +} + +func assertChanelStreamEmpty(ep *MotanCommonEndpoint, t *testing.T) { + if ep == nil { + return + } + channels := ep.channels.getChannels() + for { + select { + case c, ok := <-channels: + if !ok || c == nil { + return + } else { + c.streamLock.Lock() + // it should be zero + assert.Equal(t, 0, len(c.streams)) + c.streamLock.Unlock() + + c.heartbeatLock.Lock() + assert.Equal(t, 0, len(c.heartbeats)) + c.heartbeatLock.Unlock() + } + default: + return + } + } +} diff --git a/endpoint/motanEndpoint.go b/endpoint/motanEndpoint.go index 44c0ee50..3ea7fdb4 100644 --- a/endpoint/motanEndpoint.go +++ b/endpoint/motanEndpoint.go @@ -4,16 +4,15 @@ import ( "bufio" "errors" "fmt" + motan "github.com/weibocom/motan-go/core" + "github.com/weibocom/motan-go/log" + mpro "github.com/weibocom/motan-go/protocol" "net" "strconv" "strings" "sync" "sync/atomic" "time" - - motan "github.com/weibocom/motan-go/core" - "github.com/weibocom/motan-go/log" - mpro "github.com/weibocom/motan-go/protocol" ) var ( @@ -32,7 +31,12 @@ var ( defaultAsyncResponse = &motan.MotanResponse{Attachment: motan.NewStringMap(motan.DefaultAttachmentSize), RPCContext: &motan.RPCContext{AsyncCall: true}} - errPanic = errors.New("panic error") + errPanic = errors.New("panic error") + v2StreamPool = sync.Pool{New: func() interface{} { + return &V2Stream{ + recvNotifyCh: make(chan struct{}, 1), + } + }} ) type MotanEndpoint struct { @@ -142,9 +146,8 @@ func (m *MotanEndpoint) Call(request motan.Request) motan.Response { m.recordErrAndKeepalive() return m.defaultErrMotanResponse(request, "motanEndpoint error: channels is null") } - startTime := time.Now().UnixNano() if rc.AsyncCall { - rc.Result.StartTime = startTime + rc.Result.StartTime = time.Now().UnixNano() } // get a channel channel, err := m.channels.Get() @@ -306,16 +309,11 @@ func (m *MotanEndpoint) keepalive() { } func (m *MotanEndpoint) defaultErrMotanResponse(request motan.Request, errMsg string) motan.Response { - response := &motan.MotanResponse{ - RequestID: request.GetRequestID(), - Attachment: motan.NewStringMap(motan.DefaultAttachmentSize), - Exception: &motan.Exception{ - ErrCode: 400, - ErrMsg: errMsg, - ErrType: motan.ServiceException, - }, - } - return response + return motan.BuildExceptionResponse(request.GetRequestID(), &motan.Exception{ + ErrCode: 400, + ErrMsg: errMsg, + ErrType: motan.ServiceException, + }) } func (m *MotanEndpoint) GetName() string { @@ -380,8 +378,9 @@ type V2Channel struct { } type V2Stream struct { - channel *V2Channel - sendMsg *mpro.Message + channel *V2Channel + sendMsg *mpro.Message + streamId uint64 // recv msg recvMsg *mpro.Message recvNotifyCh chan struct{} @@ -389,30 +388,52 @@ type V2Stream struct { deadline time.Time rc *motan.RPCContext - isClose atomic.Value // bool isHeartBeat bool + timer *time.Timer + canRelease atomic.Value // state indicates whether the stream needs to be recycled by the V2Channel +} + +func (s *V2Stream) Reset() { + // try consume + select { + case <-s.recvNotifyCh: + default: + } + s.channel = nil + s.sendMsg = nil + s.recvMsg = nil + s.rc = nil } func (s *V2Stream) Send() error { - timer := time.NewTimer(s.deadline.Sub(time.Now())) - defer timer.Stop() + if s.timer == nil { + s.timer = time.NewTimer(s.deadline.Sub(time.Now())) + } else { + s.timer.Reset(s.deadline.Sub(time.Now())) + } + defer s.timer.Stop() - buf := s.sendMsg.Encode() + s.sendMsg.Encode0() if s.rc != nil && s.rc.Tc != nil { s.rc.Tc.PutReqSpan(&motan.Span{Name: motan.Encode, Addr: s.channel.address, Time: time.Now()}) } - ready := sendReady{data: buf.Bytes()} + ready := sendReady{message: s.sendMsg} select { case s.channel.sendCh <- ready: + // read/write data race if s.rc != nil { sendTime := time.Now() s.rc.RequestSendTime = sendTime if s.rc.Tc != nil { s.rc.Tc.PutReqSpan(&motan.Span{Name: motan.Send, Addr: s.channel.address, Time: sendTime}) } + if s.rc.AsyncCall { + // only return send success, it can release in V2Stream.notify + s.canRelease.Store(true) + } } return nil - case <-timer.C: + case <-s.timer.C: return ErrSendRequestTimeout case <-s.channel.shutdownCh: return ErrChannelShutdown @@ -422,10 +443,18 @@ func (s *V2Stream) Send() error { // Recv sync recv func (s *V2Stream) Recv() (*mpro.Message, error) { defer func() { - s.Close() + // only timeout or shutdown before channel.handleMsg, call RemoveFromChannel will be true, + // which means stream can release + if s.RemoveFromChannel() { + s.canRelease.Store(true) + } }() - timer := time.NewTimer(s.deadline.Sub(time.Now())) - defer timer.Stop() + if s.timer == nil { + s.timer = time.NewTimer(s.deadline.Sub(time.Now())) + } else { + s.timer.Reset(s.deadline.Sub(time.Now())) + } + defer s.timer.Stop() select { case <-s.recvNotifyCh: msg := s.recvMsg @@ -433,17 +462,18 @@ func (s *V2Stream) Recv() (*mpro.Message, error) { return nil, errors.New("recv err: recvMsg is nil") } return msg, nil - case <-timer.C: + case <-s.timer.C: + // stream may be referenced by V2Stream.notify, can`t release + s.canRelease.Store(false) return nil, ErrRecvRequestTimeout case <-s.channel.shutdownCh: + // stream may be referenced by V2Stream.notify, can`t release + s.canRelease.Store(false) return nil, ErrChannelShutdown } } func (s *V2Stream) notify(msg *mpro.Message, t time.Time) { - defer func() { - s.Close() - }() if s.rc != nil { s.rc.ResponseReceiveTime = t if s.rc.Tc != nil { @@ -451,6 +481,10 @@ func (s *V2Stream) notify(msg *mpro.Message, t time.Time) { s.rc.Tc.PutResSpan(&motan.Span{Name: motan.Decode, Time: time.Now()}) } if s.rc.AsyncCall { + defer func() { + s.RemoveFromChannel() + releaseV2Stream(s) + }() msg.Header.SetProxy(s.rc.Proxy) result := s.rc.Result response, err := mpro.ConvertToResponse(msg, s.channel.serialization) @@ -463,7 +497,7 @@ func (s *V2Stream) notify(msg *mpro.Message, t time.Time) { if err = response.ProcessDeserializable(result.Reply); err != nil { result.Error = err } - response.SetProcessTime(int64((time.Now().UnixNano() - result.StartTime) / 1000000)) + response.SetProcessTime((time.Now().UnixNano() - result.StartTime) / 1000000) if s.rc.Tc != nil { s.rc.Tc.PutResSpan(&motan.Span{Name: motan.Convert, Addr: s.channel.address, Time: time.Now()}) } @@ -480,23 +514,25 @@ func (s *V2Stream) SetDeadline(deadline time.Duration) { s.deadline = time.Now().Add(deadline) } -func (c *V2Channel) NewStream(msg *mpro.Message, rc *motan.RPCContext) (*V2Stream, error) { +func (c *V2Channel) newStream(msg *mpro.Message, rc *motan.RPCContext) (*V2Stream, error) { if msg == nil || msg.Header == nil { return nil, errors.New("msg is invalid") } if c.IsClosed() { return nil, ErrChannelShutdown } - s := &V2Stream{ - channel: c, - sendMsg: msg, - recvNotifyCh: make(chan struct{}, 1), - deadline: time.Now().Add(1 * time.Second), - rc: rc, + + s := acquireV2Stream() + s.channel = c + s.sendMsg = msg + if s.recvNotifyCh == nil { + s.recvNotifyCh = make(chan struct{}, 1) } - s.isClose.Store(false) + s.deadline = time.Now().Add(1 * time.Second) + s.rc = rc // RequestID is communication identifier, it is own by channel msg.Header.RequestID = GenerateRequestID() + s.streamId = msg.Header.RequestID if msg.Header.IsHeartbeat() { c.heartbeatLock.Lock() c.heartbeats[msg.Header.RequestID] = s @@ -506,36 +542,52 @@ func (c *V2Channel) NewStream(msg *mpro.Message, rc *motan.RPCContext) (*V2Strea c.streamLock.Lock() c.streams[msg.Header.RequestID] = s c.streamLock.Unlock() + s.isHeartBeat = false + } + if rc != nil && rc.AsyncCall { // release by Stream.Notify + s.canRelease.Store(false) + } else { // release by Channel self + s.canRelease.Store(true) } return s, nil } -func (s *V2Stream) Close() { - if !s.isClose.Load().(bool) { - if s.isHeartBeat { - s.channel.heartbeatLock.Lock() - delete(s.channel.heartbeats, s.sendMsg.Header.RequestID) - s.channel.heartbeatLock.Unlock() - } else { - s.channel.streamLock.Lock() - delete(s.channel.streams, s.sendMsg.Header.RequestID) - s.channel.streamLock.Unlock() +func (s *V2Stream) RemoveFromChannel() bool { + var exist bool + if s.isHeartBeat { + s.channel.heartbeatLock.Lock() + if _, exist = s.channel.heartbeats[s.streamId]; exist { + delete(s.channel.heartbeats, s.streamId) } - s.isClose.Store(true) + s.channel.heartbeatLock.Unlock() + } else { + s.channel.streamLock.Lock() + if _, exist = s.channel.streams[s.streamId]; exist { + delete(s.channel.streams, s.streamId) + } + s.channel.streamLock.Unlock() } + return exist } type sendReady struct { - data []byte + message *mpro.Message // motan2 protocol message + v1Message []byte //motan1 protocol, synchronize subsequent if motan1 protocol message optimization } func (c *V2Channel) Call(msg *mpro.Message, deadline time.Duration, rc *motan.RPCContext) (*mpro.Message, error) { - stream, err := c.NewStream(msg, rc) + stream, err := c.newStream(msg, rc) if err != nil { return nil, err } + defer func() { + if rc == nil || !rc.AsyncCall { + releaseV2Stream(stream) + } + }() + stream.SetDeadline(deadline) - if err := stream.Send(); err != nil { + if err = stream.Send(); err != nil { return nil, err } if rc != nil && rc.AsyncCall { @@ -558,8 +610,9 @@ func (c *V2Channel) recv() { } func (c *V2Channel) recvLoop() error { + decodeBuf := make([]byte, mpro.DefaultBufferSize) for { - res, t, err := mpro.DecodeWithTime(c.bufRead, c.config.MaxContentLength) + res, t, err := mpro.DecodeWithTime(c.bufRead, &decodeBuf, c.config.MaxContentLength) if err != nil { return err } @@ -583,18 +636,18 @@ func (c *V2Channel) send() { for { select { case ready := <-c.sendCh: - if ready.data != nil { + if ready.message != nil { + // can`t reuse net.Buffers + // len and cap will be 0 after writev consume, 'out of range panic' will happen when reuse + var sendBuf net.Buffers = ready.message.GetEncodedBytes() // TODO need async? c.conn.SetWriteDeadline(time.Now().Add(motan.DefaultWriteTimeout)) - sent := 0 - for sent < len(ready.data) { - n, err := c.conn.Write(ready.data[sent:]) - if err != nil { - vlog.Errorf("Failed to write channel. ep: %s, err: %s", c.address, err.Error()) - c.closeOnErr(err) - return - } - sent += n + _, err := sendBuf.WriteTo(c.conn) + ready.message.SetCanRelease() + if err != nil { + vlog.Errorf("Failed to write channel. ep: %s, err: %s", c.address, err.Error()) + c.closeOnErr(err) + return } } case <-c.shutdownCh: @@ -606,6 +659,7 @@ func (c *V2Channel) send() { func (c *V2Channel) handleHeartbeat(msg *mpro.Message, t time.Time) error { c.heartbeatLock.Lock() stream := c.heartbeats[msg.Header.RequestID] + delete(c.heartbeats, msg.Header.RequestID) c.heartbeatLock.Unlock() if stream == nil { vlog.Warningf("handle heartbeat message, missing stream: %d, ep:%s", msg.Header.RequestID, c.address) @@ -618,6 +672,7 @@ func (c *V2Channel) handleHeartbeat(msg *mpro.Message, t time.Time) error { func (c *V2Channel) handleMessage(msg *mpro.Message, t time.Time) error { c.streamLock.Lock() stream := c.streams[msg.Header.RequestID] + delete(c.streams, msg.Header.RequestID) c.streamLock.Unlock() if stream == nil { vlog.Warningf("handle recv message, missing stream: %d, ep:%s", msg.Header.RequestID, c.address) @@ -799,3 +854,16 @@ func GetDefaultMotanEPAsynInit() bool { } return res.(bool) } + +func acquireV2Stream() *V2Stream { + return v2StreamPool.Get().(*V2Stream) +} + +func releaseV2Stream(stream *V2Stream) { + if stream != nil { + if v, ok := stream.canRelease.Load().(bool); ok && v { + stream.Reset() + v2StreamPool.Put(stream) + } + } +} diff --git a/endpoint/motanEndpoint_test.go b/endpoint/motanEndpoint_test.go index 9bb0efe6..acb98912 100644 --- a/endpoint/motanEndpoint_test.go +++ b/endpoint/motanEndpoint_test.go @@ -21,7 +21,7 @@ func TestMain(m *testing.M) { m.Run() } -//TODO more UT +// TODO more UT func TestGetName(t *testing.T) { url := &motan.URL{Port: 8989, Protocol: "motan2"} url.PutParam(motan.TimeOutKey, "100") @@ -57,6 +57,8 @@ func TestRecordErrEmptyThreshold(t *testing.T) { ep.Call(request) assert.True(t, ep.IsAvailable()) } + + assertV2ChanelStreamEmpty(ep, t) ep.Destroy() } @@ -89,6 +91,8 @@ func TestRecordErrWithErrThreshold(t *testing.T) { ep.channels.channels <- buildV2Channel(conn, ep.channels.config, ep.channels.serialization) time.Sleep(time.Second * 2) //assert.True(t, ep.IsAvailable()) + + assertV2ChanelStreamEmpty(ep, t) ep.Destroy() } @@ -111,6 +115,8 @@ func TestMotanEndpoint_SuccessCall(t *testing.T) { s, ok := v.(string) assert.True(t, ok) assert.Equal(t, s, "hello") + + assertV2ChanelStreamEmpty(ep, t) ep.Destroy() } @@ -133,6 +139,8 @@ func TestMotanEndpoint_AsyncCall(t *testing.T) { resp := <-request.GetRPCContext(false).Result.Done assert.Nil(t, resp.Error) assert.Equal(t, resStr, "hello") + + assertV2ChanelStreamEmpty(ep, t) ep.Destroy() } @@ -158,6 +166,8 @@ func TestMotanEndpoint_ErrorCall(t *testing.T) { ep.Call(request) time.Sleep(1 * time.Millisecond) assert.Equal(t, beforeNGoroutine, runtime.NumGoroutine()) + + assertV2ChanelStreamEmpty(ep, t) ep.Destroy() } @@ -184,6 +194,8 @@ func TestMotanEndpoint_RequestTimeout(t *testing.T) { ep.Call(request) time.Sleep(1 * time.Millisecond) assert.Equal(t, beforeNGoroutine, runtime.NumGoroutine()) + + assertV2ChanelStreamEmpty(ep, t) ep.Destroy() } @@ -211,6 +223,8 @@ func TestLazyInit(t *testing.T) { ep.Call(request) time.Sleep(1 * time.Millisecond) assert.Equal(t, beforeNGoroutine, runtime.NumGoroutine()) + + assertV2ChanelStreamEmpty(ep, t) ep.Destroy() } @@ -227,6 +241,104 @@ func TestAsyncInit(t *testing.T) { time.Sleep(time.Second * 5) } +// TestMotanEndpoint_AsyncCallNoResponse verify V2Channel streams memory leak when server not reply response +// bugs to be fixed +func TestMotanEndpoint_AsyncCallNoResponse(t *testing.T) { + url := &motan.URL{Port: 8989, Protocol: "motan2"} + url.PutParam(motan.TimeOutKey, "2000") + url.PutParam(motan.ErrorCountThresholdKey, "1") + url.PutParam(motan.ClientConnectionKey, "1") + url.PutParam(motan.AsyncInitConnection, "false") + ep := &MotanEndpoint{} + ep.SetURL(url) + ep.SetSerialization(&serialize.SimpleSerialization{}) + ep.Initialize() + assert.Equal(t, 1, ep.clientConnection) + var resStr string + request := &motan.MotanRequest{ServiceName: "test", Method: "test", RPCContext: &motan.RPCContext{AsyncCall: true, Result: &motan.AsyncResult{Reply: &resStr, Done: make(chan *motan.AsyncResult, 5)}}} + request.Attachment = motan.NewStringMap(0) + + // server not reply + request.SetAttachment("no_response", "true") + + res := ep.Call(request) + assert.Nil(t, res.GetException()) + timeoutTimer := time.NewTimer(time.Second * 3) + defer timeoutTimer.Stop() + select { + case <-request.GetRPCContext(false).Result.Done: + t.Errorf("unexpect condition, recv response singnal") + case <-timeoutTimer.C: + t.Logf("expect condition, not recv response singnal") + } + + // Channel.streams cant`t release stream + c := <-ep.channels.getChannels() + // it will be zero if server not reply response, bug to be fixed + assert.Equal(t, 1, len(c.streams)) +} + +func TestV2StreamPool(t *testing.T) { + var oldStream *V2Stream + // consume v2stream poll until call New func + for { + oldStream = acquireV2Stream() + if v, ok := oldStream.canRelease.Load().(bool); !ok || !v { + break + } + } + // test new Stream + assert.NotNil(t, oldStream) + assert.NotNil(t, oldStream.recvNotifyCh) + oldStream.streamId = GenerateRequestID() + // verify reset + oldStream.recvNotifyCh <- struct{}{} + + // test canRelease + // oldStream.canRelease is not ture,release fail + // test reset recvNotifyCh + assert.Equal(t, 1, len(oldStream.recvNotifyCh)) + releaseV2Stream(oldStream) + assert.Equal(t, 1, len(oldStream.recvNotifyCh)) + // canRelease success + oldStream.canRelease.Store(true) + releaseV2Stream(oldStream) + assert.Equal(t, 0, len(oldStream.recvNotifyCh)) + + // test put nil + var nilStream *V2Stream + // can not put nil to pool + releaseV2Stream(nilStream) + newStream3 := acquireV2Stream() + assert.NotEqual(t, nil, newStream3) +} + +func assertV2ChanelStreamEmpty(ep *MotanEndpoint, t *testing.T) { + if ep == nil { + return + } + channels := ep.channels.getChannels() + for { + select { + case c, ok := <-channels: + if !ok || c == nil { + return + } else { + c.streamLock.Lock() + // it should be zero + assert.Equal(t, 0, len(c.streams)) + c.streamLock.Unlock() + + c.heartbeatLock.Lock() + assert.Equal(t, 0, len(c.heartbeats)) + c.heartbeatLock.Unlock() + } + default: + return + } + } +} + func StartTestServer(port int) *MockServer { m := &MockServer{Port: port} m.Start() @@ -268,8 +380,9 @@ func handle(netListen net.Listener) { } func handleConnection(conn net.Conn, timeout int) { - buf := bufio.NewReader(conn) - msg, _, err := protocol.DecodeWithTime(buf, 10*1024*1024) + reader := bufio.NewReader(conn) + decodeBuf := make([]byte, 100) + msg, err := protocol.Decode(reader, &decodeBuf) if err != nil { time.Sleep(time.Millisecond * 1000) conn.Close() @@ -279,6 +392,10 @@ func handleConnection(conn net.Conn, timeout int) { } func processMsg(msg *protocol.Message, conn net.Conn) { + // mock async call, but server not reply response + if _, ok := msg.Metadata.Load("no_response"); ok { + return + } var res *protocol.Message var tc *motan.TraceContext var err error diff --git a/filter/accessLog.go b/filter/accessLog.go index 06c1e6e7..ee980087 100644 --- a/filter/accessLog.go +++ b/filter/accessLog.go @@ -2,11 +2,10 @@ package filter import ( "encoding/json" - "strconv" - "time" - motan "github.com/weibocom/motan-go/core" "github.com/weibocom/motan-go/log" + "strconv" + "time" ) const ( @@ -34,15 +33,17 @@ func (t *AccessLogFilter) NewFilter(url *motan.URL) motan.Filter { func (t *AccessLogFilter) Filter(caller motan.Caller, request motan.Request) motan.Response { role := defaultRole var ip string + var start time.Time switch caller.(type) { case motan.Provider: role = serverAgentRole ip = request.GetAttachment(motan.HostKey) + start = request.GetRPCContext(true).RequestReceiveTime case motan.EndPoint: role = clientAgentRole ip = caller.GetURL().Host + start = time.Now() } - start := time.Now() response := t.GetNext().Filter(caller, request) address := ip + ":" + caller.GetURL().GetPortStr() if _, ok := caller.(motan.Provider); ok { @@ -81,9 +82,6 @@ func doAccessLog(filterName string, role string, address string, totalTime int64 // response code should be same as upstream responseCode := "" metaUpstreamCode, _ := response.GetAttachments().Load(motan.MetaUpstreamCode) - if resCtx.Meta != nil { - responseCode = resCtx.Meta.LoadOrEmpty(motan.MetaUpstreamCode) - } var exceptionData []byte if exception != nil { exceptionData, _ = json.Marshal(exception) @@ -94,21 +92,23 @@ func doAccessLog(filterName string, role string, address string, totalTime int64 responseCode = "200" } } - vlog.AccessLog(&vlog.AccessLogEntity{ - FilterName: filterName, - Role: role, - RequestID: response.GetRequestID(), - Service: request.GetServiceName(), - Method: request.GetMethod(), - RemoteAddress: address, - Desc: request.GetMethodDesc(), - ReqSize: reqCtx.BodySize, - ResSize: resCtx.BodySize, - BizTime: response.GetProcessTime(), //ms - TotalTime: totalTime, //ms - ResponseCode: responseCode, - Success: exception == nil, - Exception: string(exceptionData), - UpstreamCode: metaUpstreamCode, - }) + + logEntity := vlog.AcquireAccessLogEntity() + logEntity.FilterName = filterName + logEntity.Role = role + logEntity.RequestID = response.GetRequestID() + logEntity.Service = request.GetServiceName() + logEntity.Method = request.GetMethod() + logEntity.RemoteAddress = address + logEntity.Desc = request.GetMethodDesc() + logEntity.ReqSize = reqCtx.BodySize + logEntity.ResSize = resCtx.BodySize + logEntity.BizTime = response.GetProcessTime() //ms + logEntity.TotalTime = totalTime //ms + logEntity.ResponseCode = responseCode + logEntity.Success = exception == nil + logEntity.Exception = string(exceptionData) + logEntity.UpstreamCode = metaUpstreamCode + + vlog.AccessLog(logEntity) } diff --git a/filter/clusterMetrics.go b/filter/clusterMetrics.go index bbbbcfbe..8019f3ce 100644 --- a/filter/clusterMetrics.go +++ b/filter/clusterMetrics.go @@ -4,7 +4,6 @@ import ( "time" motan "github.com/weibocom/motan-go/core" - "github.com/weibocom/motan-go/metrics" "github.com/weibocom/motan-go/protocol" ) @@ -57,11 +56,8 @@ func (c *ClusterMetricsFilter) Filter(haStrategy motan.HaStrategy, loadBalance m if ctx != nil && ctx.Proxy { role = "motan-client-agent" } - key := metrics.Escape(role) + - ":" + metrics.Escape(request.GetAttachment(protocol.MSource)) + - ":" + metrics.Escape(request.GetMethod()) - addMetric(metrics.Escape(request.GetAttachment(protocol.MGroup))+".cluster", - metrics.Escape(request.GetAttachment(protocol.MPath)), - key, time.Since(start).Nanoseconds()/1e6, response) + keys := []string{role, request.GetAttachment(protocol.MSource), request.GetMethod()} + addMetricWithKeys(request.GetAttachment(protocol.MGroup), ".cluster", + request.GetAttachment(protocol.MPath), keys, time.Since(start).Nanoseconds()/1e6, response) return response } diff --git a/filter/filter.go b/filter/filter.go index 0126f291..38552ac8 100644 --- a/filter/filter.go +++ b/filter/filter.go @@ -2,6 +2,7 @@ package filter import ( motan "github.com/weibocom/motan-go/core" + "time" ) // ext name @@ -59,3 +60,14 @@ func RegistDefaultFilters(extFactory motan.ExtensionFactory) { return &ClusterCircuitBreakerFilter{} }) } + +func getFilterStartTime(caller motan.Caller, request motan.Request) time.Time { + switch caller.(type) { + case motan.Provider: + return request.GetRPCContext(true).RequestReceiveTime + case motan.EndPoint: + return time.Now() + default: + return time.Now() + } +} diff --git a/filter/metrics.go b/filter/metrics.go index 848eac88..c69b41fd 100644 --- a/filter/metrics.go +++ b/filter/metrics.go @@ -1,11 +1,10 @@ package filter import ( - "time" - motan "github.com/weibocom/motan-go/core" "github.com/weibocom/motan-go/metrics" "github.com/weibocom/motan-go/protocol" + "time" ) const ( @@ -57,9 +56,8 @@ func (m *MetricsFilter) GetNext() motan.EndPointFilter { } func (m *MetricsFilter) Filter(caller motan.Caller, request motan.Request) motan.Response { - start := time.Now() + start := getFilterStartTime(caller, request) response := m.GetNext().Filter(caller, request) - proxy := false provider := false ctx := request.GetRPCContext(false) @@ -86,30 +84,28 @@ func (m *MetricsFilter) Filter(caller motan.Caller, request motan.Request) motan if provider { application = caller.GetURL().GetParam(motan.ApplicationKey, "") } - key := metrics.Escape(role) + - ":" + metrics.Escape(application) + - ":" + metrics.Escape(request.GetMethod()) - addMetric(metrics.Escape(request.GetAttachment(protocol.MGroup)), - metrics.Escape(request.GetAttachment(protocol.MPath)), - key, time.Since(start).Nanoseconds()/1e6, response) + keys := []string{role, application, request.GetMethod()} + addMetricWithKeys(request.GetAttachment(protocol.MGroup), "", request.GetAttachment(protocol.MPath), + keys, time.Since(start).Nanoseconds()/1e6, response) return response } -func addMetric(group string, service string, key string, cost int64, response motan.Response) { - metrics.AddCounter(group, service, key+MetricsTotalCountSuffix, 1) //total_count - if response.GetException() != nil { //err_count +// addMetricWithKeys arguments: group & groupSuffix & service & keys elements is text without escaped +func addMetricWithKeys(group, groupSuffix string, service string, keys []string, cost int64, response motan.Response) { + metrics.AddCounterWithKeys(group, groupSuffix, service, keys, MetricsTotalCountSuffix, 1) //total_count + if response.GetException() != nil { //err_count exception := response.GetException() if exception.ErrType == motan.BizException { - metrics.AddCounter(group, service, key+MetricsBizErrorCountSuffix, 1) + metrics.AddCounterWithKeys(group, groupSuffix, service, keys, MetricsBizErrorCountSuffix, 1) } else { - metrics.AddCounter(group, service, key+MetricsOtherErrorCountSuffix, 1) + metrics.AddCounterWithKeys(group, groupSuffix, service, keys, MetricsOtherErrorCountSuffix, 1) } } - metrics.AddCounter(group, service, key+metrics.ElapseTimeSuffix(cost), 1) + metrics.AddCounterWithKeys(group, groupSuffix, service, keys, metrics.ElapseTimeSuffix(cost), 1) if cost > 200 { - metrics.AddCounter(group, service, key+MetricsSlowCountSuffix, 1) + metrics.AddCounterWithKeys(group, groupSuffix, service, keys, MetricsSlowCountSuffix, 1) } - metrics.AddHistograms(group, service, key, cost) + metrics.AddHistogramsWithKeys(group, groupSuffix, service, keys, "", cost) } func (m *MetricsFilter) SetContext(context *motan.Context) { diff --git a/filter/metrics_test.go b/filter/metrics_test.go index f643882f..63cc5cc9 100644 --- a/filter/metrics_test.go +++ b/filter/metrics_test.go @@ -31,7 +31,7 @@ func TestMetricsFilter(t *testing.T) { request.GetRPCContext(true).Proxy = true request.SetAttachment(protocol.MSource, application) request.SetAttachment(protocol.MPath, testService) - assert.Nil(t, metrics.GetStatItem(testGroup, testService), "metric stat") + assert.Nil(t, metrics.GetStatItem(testGroup, "", testService), "metric stat") ep := factory.GetEndPoint(url) provider := factory.GetProvider(url) @@ -41,25 +41,28 @@ func TestMetricsFilter(t *testing.T) { name string caller motan.Caller request motan.Request - key string + keys []string }{ - {name: "proxyClient", caller: ep, request: request, key: "motan-client-agent:" + metrics.Escape(application) + ":" + metrics.Escape(testMethod)}, - {name: "proxyServer", caller: provider, request: request, key: "motan-server-agent:" + metrics.Escape(application) + ":" + metrics.Escape(testMethod)}, - {name: "Client", caller: ep, request: request2, key: "motan-client:" + metrics.Escape(application) + ":" + metrics.Escape(testMethod)}, - {name: "Server", caller: provider, request: request2, key: "motan-server:" + metrics.Escape(application) + ":" + metrics.Escape(testMethod)}, + {name: "proxyClient", caller: ep, request: request, keys: []string{"motan-client-agent", application, testMethod}}, + {name: "proxyServer", caller: provider, request: request, keys: []string{"motan-server-agent", application, testMethod}}, + {name: "Client", caller: ep, request: request2, keys: []string{"motan-client", application, testMethod}}, + {name: "Server", caller: provider, request: request2, keys: []string{"motan-server", application, testMethod}}, + } + var getKeysStr = func(keys []string) string { + return metrics.Escape(keys[0]) + ":" + metrics.Escape(keys[1]) + ":" + metrics.Escape(keys[2]) } for _, test := range tests { t.Run(test.name, func(t *testing.T) { mf.Filter(test.caller, test.request) time.Sleep(10 * time.Millisecond) // The metrics filter has do escape - assert.Equal(t, 1, int(metrics.GetStatItem(metrics.Escape(testGroup), metrics.Escape(testService)).SnapshotAndClear().Count(test.key+MetricsTotalCountSuffix)), "metric count") + assert.Equal(t, 1, int(metrics.GetStatItem(testGroup, "", testService).SnapshotAndClear().Count(getKeysStr(test.keys)+MetricsTotalCountSuffix)), "metric count") }) } } func TestAddMetric(t *testing.T) { - key := "motan-client-agent:testApplication:" + testMethod + keys := []string{"motan-client-agent", "testApplication", testMethod} factory := initFactory() mf := factory.GetFilter(Metrics).(motan.EndPointFilter) mf.(*MetricsFilter).SetContext(&motan.Context{Config: config.NewConfig()}) @@ -67,7 +70,9 @@ func TestAddMetric(t *testing.T) { response2 := &motan.MotanResponse{ProcessTime: 100, Exception: &motan.Exception{ErrType: motan.BizException}} response3 := &motan.MotanResponse{ProcessTime: 100, Exception: &motan.Exception{ErrType: motan.FrameworkException}} response4 := &motan.MotanResponse{ProcessTime: 1000} - + var getKeysStr = func(keys []string) string { + return metrics.Escape(keys[0]) + ":" + metrics.Escape(keys[1]) + ":" + metrics.Escape(keys[2]) + } tests := []struct { name string response motan.Response @@ -81,11 +86,11 @@ func TestAddMetric(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - addMetric(testGroup, testService, key, test.response.GetProcessTime(), test.response) + addMetricWithKeys(testGroup, "", testService, keys, test.response.GetProcessTime(), test.response) time.Sleep(10 * time.Millisecond) - snap := metrics.GetStatItem(testGroup, testService).SnapshotAndClear() + snap := metrics.GetStatItem(testGroup, "", testService).SnapshotAndClear() for _, k := range test.keys { - assert.True(t, snap.Count(key+k) > 0, fmt.Sprintf("key '%s'", k)) + assert.True(t, snap.Count(getKeysStr(keys)+k) > 0, fmt.Sprintf("key '%s'", k)) } }) } diff --git a/filter/rateLimit.go b/filter/rateLimit.go index 85d5ac22..e98692b1 100644 --- a/filter/rateLimit.go +++ b/filter/rateLimit.go @@ -137,20 +137,20 @@ func (r *RateLimitFilter) GetType() int32 { func getKeyValue(key, value, prefix string) (string, float64, bool) { if strings.HasPrefix(key, prefix) { if temp := strings.Split(key, prefix); len(temp) == 2 { - if r, err := strconv.ParseFloat(value, 64); err == nil && temp[1] != "" && r > 0 { + r, err := strconv.ParseFloat(value, 64) + if err == nil && temp[1] != "" && r > 0 { return temp[1], r, true + } + if err != nil { + vlog.Warningf("[rateLimit] parse %s config error:%s", key, err.Error()) } else { - if err != nil { - vlog.Warningf("[rateLimit] parse %s config error:%s", key, err.Error()) - } else { - if r <= 0 { - vlog.Warningf("[rateLimit] parse %s config error: value is 0 or negative", key) - } - } - if temp[1] == "" { - vlog.Warningf("[rateLimit] parse %s config error: key is empty", key) + if r <= 0 { + vlog.Warningf("[rateLimit] parse %s config error: value is 0 or negative", key) } } + if temp[1] == "" { + vlog.Warningf("[rateLimit] parse %s config error: key is empty", key) + } } } return "", 0, false diff --git a/filter/tracing.go b/filter/tracing.go index ac602c85..aa855087 100644 --- a/filter/tracing.go +++ b/filter/tracing.go @@ -65,48 +65,48 @@ func tracingFunc() func(span ot.Span, data *CallData) { // As described by OpenTracing, for a single call from client to server, both sides will start a span, // with the server side span to be child of the client side. Described as following // -// call -// caller ---------------> callee -// [span1] [span2] +// call +// caller ---------------> callee +// [span1] [span2] // // here [span1] is parent of [span2]. // // When this filter is applied, it will filter both the incoming and // outgoing requests to record trace information. The following diagram is a demonstration. // -// filter -// +---------+ -// | | [span1] -// [span2] *-----+-- in <--+--------------------- | user | -// | | | -// V | | -// | ------- | | -// pass-thru | service | | -// [span2] V ------- | | -// | | | -// | | | [span3] -// [span2] *-----+-> out --+--------------------> | dep | -// | | -// +---------+ +// filter +// +---------+ +// | | [span1] +// [span2] *-----+-- in <--+--------------------- | user | +// | | | +// V | | +// | ------- | | +// pass-thru | service | | +// [span2] V ------- | | +// | | | +// | | | [span3] +// [span2] *-----+-> out --+--------------------> | dep | +// | | +// +---------+ // // When the filter receives an incoming request, it will: // -// 1. extract span context from request (will get [span1]) -// 2. start a child span of the extracted span ([span2], child of [span1]) -// 3. forward the request with [span2] to the service +// 1. extract span context from request (will get [span1]) +// 2. start a child span of the extracted span ([span2], child of [span1]) +// 3. forward the request with [span2] to the service // // Then the service may make an outgoing request to some dependent services, // it should pass-through the span information ([span2]). // The filter will receive the outgoing request with [span2], then it will. // -// 1. extract span context from the outgoing request (it should the [span2]) -// 2. start a child span of the extracted span ([span3], child of [span2]) -// 3. forward the request with [span3] to the dependent service +// 1. extract span context from the outgoing request (it should the [span2]) +// 2. start a child span of the extracted span ([span3], child of [span2]) +// 3. forward the request with [span3] to the dependent service // // So here // -// (parent) (parent) -// [span1] <------ [span2] <------ [span3] +// (parent) (parent) +// [span1] <------ [span2] <------ [span3] // // NOTE: // diff --git a/go.mod b/go.mod index 989f8b55..e7623234 100644 --- a/go.mod +++ b/go.mod @@ -15,12 +15,13 @@ require ( github.com/kr/pretty v0.3.1 // indirect github.com/mitchellh/mapstructure v1.1.2 github.com/opentracing/opentracing-go v1.0.2 + github.com/panjf2000/ants/v2 v2.9.0 github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a github.com/samuel/go-zookeeper v0.0.0-20180130194729-c4fab1ac1bec github.com/shirou/gopsutil/v3 v3.21.9 github.com/smartystreets/goconvey v1.6.4 // indirect - github.com/stretchr/testify v1.7.0 + github.com/stretchr/testify v1.8.2 github.com/valyala/fasthttp v1.2.0 github.com/weibreeze/breeze-go v0.1.1 go.uber.org/atomic v1.4.0 // indirect diff --git a/ha/backupRequestHA.go b/ha/backupRequestHA.go index d50c6e5b..8ee949d9 100644 --- a/ha/backupRequestHA.go +++ b/ha/backupRequestHA.go @@ -67,7 +67,7 @@ func (br *BackupRequestHA) Call(request motan.Request, loadBalance motan.LoadBal successCh := make(chan motan.Response, retries+1) if delay <= 0 { //no delay time configuration // TODO: we should use metrics of the cluster, with traffic control the group may changed - item := metrics.GetStatItem(metrics.Escape(request.GetAttachment(protocol.MGroup)), metrics.Escape(request.GetAttachment(protocol.MPath))) + item := metrics.GetStatItem(request.GetAttachment(protocol.MGroup), "", request.GetAttachment(protocol.MPath)) if item == nil || item.LastSnapshot() == nil { initDelay := int(br.url.GetMethodPositiveIntValue(request.GetMethod(), request.GetMethodDesc(), "backupRequestInitDelayTime", 0)) if initDelay == 0 { diff --git a/ha/failoverHA.go b/ha/failoverHA.go index 1701cf50..e5d20bab 100644 --- a/ha/failoverHA.go +++ b/ha/failoverHA.go @@ -18,12 +18,15 @@ type FailOverHA struct { func (f *FailOverHA) GetName() string { return FailOver } + func (f *FailOverHA) GetURL() *motan.URL { return f.url } + func (f *FailOverHA) SetURL(url *motan.URL) { f.url = url } + func (f *FailOverHA) Call(request motan.Request, loadBalance motan.LoadBalance) motan.Response { retries := f.url.GetMethodPositiveIntValue(request.GetMethod(), request.GetMethodDesc(), motan.RetriesKey, defaultRetries) var lastErr *motan.Exception diff --git a/http/httpProxy.go b/http/httpProxy.go index dc048722..33738633 100644 --- a/http/httpProxy.go +++ b/http/httpProxy.go @@ -4,10 +4,12 @@ import ( "bytes" "errors" "fmt" + mpro "github.com/weibocom/motan-go/protocol" "net/url" "reflect" "regexp" "strings" + "sync" "github.com/valyala/fasthttp" "github.com/weibocom/motan-go/core" @@ -21,8 +23,10 @@ const ( ) const ( - HeaderPrefix = "http_" - HeaderContentType = "Content-Type" + HeaderPrefix = "http_" + HeaderContentType = "Content-Type" + HeaderContentTypeLowerCase = "content-type" + HeaderHost = "Host" ) const ( @@ -60,8 +64,20 @@ const ( var ( WhitespaceSplitPattern = regexp.MustCompile(`\s+`) findRewriteVarPattern = regexp.MustCompile(`\{[0-9a-zA-Z_-]+\}`) - httpProxySpecifiedAttachments = []string{Proxy, Method, QueryString, core.HostKey} - rewriteVarFunc = func(condType ProxyRewriteType, uri string, queryBytes []byte) string { + httpProxySpecifiedAttachments = []string{Proxy, Method, QueryString, core.HostKey, HeaderHost} + InnerAttachmentsConvertMap = map[string]string{ + mpro.MPath: "MOTAN-p", + mpro.MMethod: "MOTAN-m", + mpro.MMethodDesc: "MOTAN-md", + mpro.MGroup: "MOTAN-g", + mpro.MProxyProtocol: "MOTAN-pp", + mpro.MVersion: "MOTAN-v", + mpro.MModule: "MOTAN-mdu", + mpro.MSource: "MOTAN-s", + mpro.MRequestID: "MOTAN-rid", + mpro.MTimeout: "MOTAN-tmo", + } + rewriteVarFunc = func(condType ProxyRewriteType, uri string, queryBytes []byte) string { if condType != proxyRewriteTypeRegexpVar || len(queryBytes) == 0 { return uri } @@ -73,8 +89,43 @@ var ( } return uri } + httpResponsePool = sync.Pool{New: func() interface{} { + res := &HttpMotanResponse{} + res.RPCContext = &core.RPCContext{} + return res + }} ) +// HttpMotanResponse Resposne used in http provider to deal with http response +type HttpMotanResponse struct { + core.MotanResponse + HttpResponse *fasthttp.Response +} + +func (m *HttpMotanResponse) Reset() { + m.RequestID = 0 + m.Value = nil + m.Exception = nil + m.ProcessTime = 0 + m.Attachment = nil + m.RPCContext.Reset() + m.HttpResponse = nil +} + +func AcquireHttpMotanResponse() *HttpMotanResponse { + return httpResponsePool.Get().(*HttpMotanResponse) +} + +func ReleaseHttpMotanResponse(m *HttpMotanResponse) { + if m != nil { + if m.HttpResponse != nil { + fasthttp.ReleaseResponse(m.HttpResponse) + } + m.Reset() + httpResponsePool.Put(m) + } +} + func PatternSplit(s string, pattern *regexp.Regexp) []string { matches := pattern.FindAllStringIndex(s, -1) strings := make([]string, 0, len(matches)) @@ -431,25 +482,31 @@ func MotanRequestToFasthttpRequest(motanRequest core.Request, fasthttpRequest *f fasthttpRequest.URI().SetQueryString(queryString) } motanRequest.GetAttachments().Range(func(k, v string) bool { + if convertK, ok := InnerAttachmentsConvertMap[k]; ok { + k = convertK + fasthttpRequest.Header.Set(k, v) + return true + } // ignore some specified key for _, attachmentKey := range httpProxySpecifiedAttachments { - if strings.EqualFold(k, attachmentKey) { + if k == attachmentKey { return true } } // fasthttp will use a special field to store this header - if strings.EqualFold(k, HeaderPrefix+core.HostKey) { + if k == HeaderPrefix+core.HostKey || k == HeaderPrefix+HeaderHost { fasthttpRequest.Header.SetHost(v) return true } - if strings.EqualFold(k, HeaderContentType) { + if k == HeaderContentType || k == HeaderContentTypeLowerCase { fasthttpRequest.Header.SetContentType(v) return true } - k = strings.Replace(k, "M_", "MOTAN-", -1) // http header private prefix - k = strings.Replace(k, HeaderPrefix, "", -1) - fasthttpRequest.Header.Add(k, v) + if strings.HasPrefix(k, HeaderPrefix) { + k = strings.Replace(k, HeaderPrefix, "", 1) + } + fasthttpRequest.Header.Set(k, v) return true }) fasthttpRequest.Header.Del("Connection") @@ -501,12 +558,11 @@ func FasthttpResponseToMotanResponse(motanResponse core.Response, fasthttpRespon fasthttpResponse.Header.VisitAll(func(k, v []byte) { motanResponse.SetAttachment(string(k), string(v)) }) - if resp, ok := motanResponse.(*core.MotanResponse); ok { + if resp, ok := motanResponse.(*HttpMotanResponse); ok { httpResponseBody := fasthttpResponse.Body() - if httpResponseBody != nil { - motanResponseBody := make([]byte, len(httpResponseBody)) - copy(motanResponseBody, httpResponseBody) - resp.Value = motanResponseBody - } + // direct assign to resp.Value + resp.Value = httpResponseBody + // record this http response, release it when releasing HttpMotanResponse + resp.HttpResponse = fasthttpResponse } } diff --git a/lb/randomLb.go b/lb/randomLb.go index a8869526..48ef206b 100644 --- a/lb/randomLb.go +++ b/lb/randomLb.go @@ -13,11 +13,13 @@ type RandomLB struct { func (r *RandomLB) OnRefresh(endpoints []motan.EndPoint) { r.endpoints = endpoints } + func (r *RandomLB) Select(request motan.Request) motan.EndPoint { eps := r.endpoints _, endpoint := SelectOneAtRandom(eps) return endpoint } + func (r *RandomLB) SelectArray(request motan.Request) []motan.EndPoint { eps := r.endpoints index, endpoint := SelectOneAtRandom(eps) diff --git a/log/bytes.go b/log/bytes.go new file mode 100644 index 00000000..f9582fc0 --- /dev/null +++ b/log/bytes.go @@ -0,0 +1,67 @@ +package vlog + +import ( + "strconv" + "sync" +) + +var ( + initSize = 256 + innerBytesBufferPool = sync.Pool{New: func() interface{} { + return &innerBytesBuffer{buf: make([]byte, 0, initSize)} + }} +) + +// innerBytesBuffer is a variable-sized buffer of bytes with Write methods. +type innerBytesBuffer struct { + buf []byte // reuse +} + +// WriteString write a str string append the innerBytesBuffer +func (b *innerBytesBuffer) WriteString(str string) { + b.buf = append(b.buf, str...) +} + +// WriteBoolString append the string value of v(true/false) to innerBytesBuffer +func (b *innerBytesBuffer) WriteBoolString(v bool) { + if v { + b.WriteString("true") + } else { + b.WriteString("false") + } +} + +// WriteUint64String append the string value of u to innerBytesBuffer +func (b *innerBytesBuffer) WriteUint64String(u uint64) { + b.buf = strconv.AppendUint(b.buf, u, 10) +} + +// WriteInt64String append the string value of i to innerBytesBuffer +func (b *innerBytesBuffer) WriteInt64String(i int64) { + b.buf = strconv.AppendInt(b.buf, i, 10) +} + +func (b *innerBytesBuffer) Bytes() []byte { return b.buf } + +func (b *innerBytesBuffer) String() string { + return string(b.buf) +} + +func (b *innerBytesBuffer) Reset() { + b.buf = b.buf[:0] +} + +func (b *innerBytesBuffer) Len() int { return len(b.buf) } + +func (b *innerBytesBuffer) Cap() int { return cap(b.buf) } + +func acquireBytesBuffer() *innerBytesBuffer { + return innerBytesBufferPool.Get().(*innerBytesBuffer) +} + +func releaseBytesBuffer(b *innerBytesBuffer) { + if b != nil { + b.Reset() + innerBytesBufferPool.Put(b) + } +} diff --git a/log/bytes_test.go b/log/bytes_test.go new file mode 100644 index 00000000..333c0ece --- /dev/null +++ b/log/bytes_test.go @@ -0,0 +1,104 @@ +package vlog + +import ( + "strconv" + "testing" +) + +func TestWrite(t *testing.T) { + // new BytesBuffer + buf := acquireBytesBuffer() + if buf.Len() != 0 { + t.Errorf("new buf length not zero.") + } + if buf.Cap() != initSize { + t.Errorf("buf cap not correct.real:%d, expect:%d\n", buf.Cap(), initSize) + } + + // write string + buf.Reset() + buf.WriteString("string1") + buf.WriteString("string2") + tempbytes := buf.Bytes() + if "string1" != string(tempbytes[:7]) { + t.Errorf("write string not correct.buf:%+v\n", buf) + } + if "string2" != string(tempbytes[7:14]) { + t.Errorf("write string not correct.buf:%+v\n", buf) + } + + // write bool string + buf.Reset() + buf.WriteBoolString(true) + buf.WriteBoolString(false) + tempbytes = buf.Bytes() + if "true" != string(tempbytes[:4]) { + t.Errorf("write bool string not correct.buf:%+v\n", buf) + } + if "false" != string(tempbytes[4:9]) { + t.Errorf("write bool string not correct.buf:%+v\n", buf) + } + + // write uint64 string + buf.Reset() + var u1 uint64 = 11111111 + var u2 uint64 = 22222222 + buf.WriteUint64String(u1) + buf.WriteUint64String(u2) + tempbytes = buf.Bytes() + if "11111111" != string(tempbytes[:8]) { + t.Errorf("write unit64 string not correct.buf:%+v\n", buf) + } + if "22222222" != string(tempbytes[8:]) { + t.Errorf("write uint64 string not correct.buf:%+v\n", buf) + } + + // write int64 string + buf.Reset() + var i1 int64 = 11111111 + var i2 int64 = -22222222 + buf.WriteInt64String(i1) + buf.WriteInt64String(i2) + tempbytes = buf.Bytes() + if "11111111" != string(tempbytes[:8]) { + t.Errorf("write unit64 string not correct.buf:%+v\n", buf) + } + if "-22222222" != string(tempbytes[8:]) { + t.Errorf("write uint64 string not correct.buf:%+v\n", buf) + } +} + +func TestRead(t *testing.T) { + buf := acquireBytesBuffer() + string1 := "aaaaaaaaaaaa" + buf.WriteString(string1) + buf.WriteUint64String(uint64(len(string1))) + buf.WriteBoolString(false) + buf.WriteInt64String(int64(-len(string1))) + + string2 := "bbbbbbbbbbbb" + buf.WriteString(string2) + buf.WriteUint64String(uint64(len(string2))) + buf.WriteBoolString(true) + buf.WriteInt64String(int64(-len(string2))) + + data := buf.Bytes() + buf2 := &innerBytesBuffer{data} + rsize := len(string1) + 2 + 5 + 3 + len(string2) + 2 + 4 + 3 + if buf2.Len() != rsize { + t.Errorf("read buf len not correct. buf:%v\n", buf2) + } + + // read value + expectValue := string1 + + strconv.Itoa(len(string1)) + + "false" + + "-" + strconv.Itoa(len(string1)) + + string2 + + strconv.Itoa(len(string2)) + + "true" + + "-" + strconv.Itoa(len(string2)) + if expectValue != buf2.String() { + t.Errorf("read value not correct. buf:%v\n", buf2) + } +} diff --git a/log/log.go b/log/log.go index d1d484bb..a3b71c2e 100644 --- a/log/log.go +++ b/log/log.go @@ -1,14 +1,12 @@ package vlog import ( - "bytes" "flag" "github.com/weibocom/motan-go/metrics/sampler" "log" "os" "path/filepath" "runtime/debug" - "strconv" "sync" "sync/atomic" "time" @@ -18,6 +16,9 @@ import ( ) var ( + accessLogEntityPool = sync.Pool{New: func() interface{} { + return new(AccessLogEntity) + }} loggerInstance Logger once sync.Once logDir = flag.String("log_dir", ".", "If non-empty, write log files in this directory") @@ -390,12 +391,13 @@ func (d *defaultLogger) doAccessLog(logObject *AccessLogEntity) { zap.String("exception", logObject.Exception), zap.String("upstreamCode", logObject.UpstreamCode)) } else { - var buffer bytes.Buffer + buffer := acquireBytesBuffer() + buffer.WriteString(logObject.FilterName) buffer.WriteString("|") buffer.WriteString(logObject.Role) buffer.WriteString("|") - buffer.WriteString(strconv.FormatUint(logObject.RequestID, 10)) + buffer.WriteUint64String(logObject.RequestID) buffer.WriteString("|") buffer.WriteString(logObject.Service) buffer.WriteString("|") @@ -405,15 +407,15 @@ func (d *defaultLogger) doAccessLog(logObject *AccessLogEntity) { buffer.WriteString("|") buffer.WriteString(logObject.RemoteAddress) buffer.WriteString("|") - buffer.WriteString(strconv.Itoa(logObject.ReqSize)) + buffer.WriteInt64String(int64(logObject.ReqSize)) buffer.WriteString("|") - buffer.WriteString(strconv.Itoa(logObject.ResSize)) + buffer.WriteInt64String(int64(logObject.ResSize)) buffer.WriteString("|") - buffer.WriteString(strconv.FormatInt(logObject.BizTime, 10)) + buffer.WriteInt64String(logObject.BizTime) buffer.WriteString("|") - buffer.WriteString(strconv.FormatInt(logObject.TotalTime, 10)) + buffer.WriteInt64String(logObject.TotalTime) buffer.WriteString("|") - buffer.WriteString(strconv.FormatBool(logObject.Success)) + buffer.WriteBoolString(logObject.Success) buffer.WriteString("|") buffer.WriteString(logObject.ResponseCode) buffer.WriteString("|") @@ -421,7 +423,10 @@ func (d *defaultLogger) doAccessLog(logObject *AccessLogEntity) { buffer.WriteString("|") buffer.WriteString(logObject.UpstreamCode) d.accessLogger.Info(buffer.String()) + + releaseBytesBuffer(buffer) } + ReleaseAccessLogEntity(logObject) } func (d *defaultLogger) MetricsLog(msg string) { @@ -490,3 +495,13 @@ func (d *defaultLogger) SetMetricsLogAvailable(status bool) { d.metricsLevel.SetLevel(zapcore.Level(defaultLogLevel + 1)) } } + +func AcquireAccessLogEntity() *AccessLogEntity { + return accessLogEntityPool.Get().(*AccessLogEntity) +} + +func ReleaseAccessLogEntity(entity *AccessLogEntity) { + if entity != nil { + accessLogEntityPool.Put(entity) + } +} diff --git a/log/log_test.go b/log/log_test.go index fe025fbc..b3552bec 100644 --- a/log/log_test.go +++ b/log/log_test.go @@ -31,7 +31,7 @@ func init() { Exception: "Exception"} } -//BenchmarkLogSprintf: 736 ns/op +// BenchmarkLogSprintf: 736 ns/op func BenchmarkLogSprintf(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { @@ -52,7 +52,7 @@ func BenchmarkLogSprintf(b *testing.B) { } } -//BenchmarkLogBufferWritePlus: 438 ns/op +// BenchmarkLogBufferWritePlus: 438 ns/op func BenchmarkLogBufferWritePlus(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { @@ -74,7 +74,7 @@ func BenchmarkLogBufferWritePlus(b *testing.B) { } } -//BenchmarkLogBufferWrite: 406 ns/op +// BenchmarkLogBufferWrite: 406 ns/op func BenchmarkLogBufferWrite(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { diff --git a/log/rotate.go b/log/rotate.go index 18d7ca7b..652027bb 100644 --- a/log/rotate.go +++ b/log/rotate.go @@ -1,10 +1,10 @@ -//********** this file is modified based on "gopkg.in/natefinch/lumberjack.v2" ********** +// Package vlog ********** this file is modified based on "gopkg.in/natefinch/lumberjack.v2" ********** // Package lumberjack provides a rolling logger. // // Note that this is v2.0 of lumberjack, and should be imported using gopkg.in // thusly: // -// import "gopkg.in/natefinch/lumberjack.v2" +// import "gopkg.in/natefinch/lumberjack.v2" // // The package name remains simply lumberjack, and the code resides at // https://github.com/natefinch/lumberjack under the v2.0 branch. @@ -69,7 +69,7 @@ var _ io.WriteCloser = (*RotateWriter)(nil) // `/var/log/foo/server.log`, a backup created at 6:30pm on Nov 11 2016 would // use the filename `/var/log/foo/server-2016-11-04T18-30-00.000.log` // -// Cleaning Up Old Log Files +// # Cleaning Up Old Log Files // // Whenever a new logfile gets created, old log files may be deleted. The most // recent files according to the encoded timestamp will be retained, up to a diff --git a/manageHandler.go b/manageHandler.go index 98fd4208..f8e79085 100644 --- a/manageHandler.go +++ b/manageHandler.go @@ -158,7 +158,7 @@ func (s *StatusHandler) getStatus() []byte { exporter := v.(motan.Exporter) group := exporter.GetURL().Group service := exporter.GetURL().Path - statItem := metrics.GetStatItem(metrics.Escape(group), metrics.Escape(service)) + statItem := metrics.GetStatItem(group, "", service) if statItem == nil { return true } diff --git a/manageHandler_test.go b/manageHandler_test.go index 55bafec9..10655465 100644 --- a/manageHandler_test.go +++ b/manageHandler_test.go @@ -20,7 +20,7 @@ func TestGetAllService(t *testing.T) { i := InfoHandler{} ctx := &motan.Context{ ServiceURLs: map[string]*motan.URL{ - "test": &motan.URL{ + "test": { Group: "testgroup", Path: "testpath", }, diff --git a/metrics/graphite.go b/metrics/graphite.go index 93f80cb5..45324687 100644 --- a/metrics/graphite.go +++ b/metrics/graphite.go @@ -32,7 +32,7 @@ type graphite struct { Port int Name string localIP string - lock sync.Mutex + lock *sync.Mutex //using pointer avoid shadow copy, cause lock issue conn net.Conn } @@ -50,7 +50,7 @@ func newGraphite(ip, pool string, port int) *graphite { Host: ip, Port: port, Name: pool, - lock: sync.Mutex{}, + lock: &sync.Mutex{}, conn: getUDPConn(ip, port), } } @@ -102,19 +102,21 @@ func GenGraphiteMessages(localIP string, snapshots []Snapshot) []string { if len(pni) < minKeyLength { return } + escapedService := snap.GetEscapedService() + escapedGroup := snap.GetEscapedGroup() + snap.GetGroupSuffix() if snap.IsHistogram(k) { //histogram for slaK, slaV := range sla { segment += fmt.Sprintf("%s.%s.%s.byhost.%s.%s.%s.%s:%.2f|kv\n", - pni[0], pni[1], snap.GetGroup(), localIP, snap.GetService(), pni[2], slaK, snap.Percentile(k, slaV)) + pni[0], pni[1], escapedGroup, localIP, escapedService, pni[2], slaK, snap.Percentile(k, slaV)) } segment += fmt.Sprintf("%s.%s.%s.byhost.%s.%s.%s.%s:%.2f|ms\n", - pni[0], pni[1], snap.GetGroup(), localIP, snap.GetService(), pni[2], "avg_time", snap.Mean(k)) + pni[0], pni[1], escapedGroup, localIP, escapedService, pni[2], "avg_time", snap.Mean(k)) } else if snap.IsCounter(k) { //counter segment = fmt.Sprintf("%s.%s.%s.byhost.%s.%s.%s:%d|c\n", - pni[0], pni[1], snap.GetGroup(), localIP, snap.GetService(), pni[2], snap.Count(k)) + pni[0], pni[1], escapedGroup, localIP, escapedService, pni[2], snap.Count(k)) } else { // gauge segment = fmt.Sprintf("%s.%s.%s.byhost.%s.%s.%s:%d|kv\n", - pni[0], pni[1], snap.GetGroup(), localIP, snap.GetService(), pni[2], snap.Value(k)) + pni[0], pni[1], escapedGroup, localIP, escapedService, pni[2], snap.Value(k)) } if buf.Len() > 0 && buf.Len()+len(segment) > messageMaxLen { messages = append(messages, buf.String()) diff --git a/metrics/graphite_test.go b/metrics/graphite_test.go index fceb856d..dbc3c068 100644 --- a/metrics/graphite_test.go +++ b/metrics/graphite_test.go @@ -25,7 +25,7 @@ func Test_graphite_Write(t *testing.T) { go server.start() time.Sleep(100 * time.Millisecond) g := newGraphite("127.0.0.1", "test pool", 3456) - item := NewStatItem(group, service) + item := NewStatItem(group, "", service) item.AddCounter(keyPrefix+"c1", 1) item.AddHistograms(keyPrefix+" h1", 100) err := g.Write([]Snapshot{item.SnapshotAndClear()}) @@ -48,7 +48,7 @@ func Test_graphite_Write(t *testing.T) { } func TestGenGraphiteMessages(t *testing.T) { - item1 := NewDefaultStatItem(group, service) + item1 := NewDefaultStatItem(group, "", service) // counter message item1.AddCounter(keyPrefix+"c1", 1) @@ -69,8 +69,8 @@ func TestGenGraphiteMessages(t *testing.T) { role, application, group, localhost, service, methodPrefix+"h1", "avg_time", float32(100))), "histogram message") // multi items - item2 := NewDefaultStatItem(group+"2", service+"2") - item3 := NewDefaultStatItem(group+"3", service+"3") + item2 := NewDefaultStatItem(group+"2", "", service+"2") + item3 := NewDefaultStatItem(group+"3", "", service+"3") item3.SetReport(false) // item3 not report length := 10 diff --git a/metrics/metrics.go b/metrics/metrics.go index 574917b1..614e9dc1 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -42,6 +42,7 @@ const ( ) var ( + metricsKeyBuilderBufferSize = 64 // NewStatItem is the factory func for StatItem NewStatItem = NewDefaultStatItem items = make(map[string]StatItem, 64) @@ -52,15 +53,21 @@ var ( processor: defaultEventProcessor, //sink processor size eventBus: make(chan *event, eventBufferSize), writers: make(map[string]StatWriter), - evtBuf: &sync.Pool{New: func() interface{} { return new(event) }}, + eventPool: &sync.Pool{New: func() interface{} { + return &event{} + }}, } + escapeCache sync.Map ) type StatItem interface { SetService(service string) GetService() string + GetEscapedService() string SetGroup(group string) GetGroup() string + GetEscapedGroup() string + GetGroupSuffix() string AddCounter(key string, value int64) AddHistograms(key string, duration int64) AddGauge(key string, value int64) @@ -97,42 +104,43 @@ type StatWriter interface { Write(snapshots []Snapshot) error } -func GetOrRegisterStatItem(group string, service string) StatItem { +func GetOrRegisterStatItem(group, groupSuffix string, service string) StatItem { + k := group + groupSuffix + service itemsLock.RLock() - item := items[group+service] + item := items[k] itemsLock.RUnlock() if item != nil { return item } itemsLock.Lock() - item = items[group+service] + item = items[k] if item == nil { - item = NewStatItem(group, service) - items[group+service] = item + item = NewStatItem(group, groupSuffix, service) + items[k] = item } itemsLock.Unlock() return item } -func GetStatItem(group string, service string) StatItem { +func GetStatItem(group, groupSuffix string, service string) StatItem { itemsLock.RLock() defer itemsLock.RUnlock() - return items[group+service] + return items[group+groupSuffix+service] } // NewDefaultStatItem create a new statistic item, you should escape input parameter before call this function -func NewDefaultStatItem(group string, service string) StatItem { - return &DefaultStatItem{group: group, service: service, holder: &RegistryHolder{registry: metrics.NewRegistry()}, isReport: true} +func NewDefaultStatItem(group, groupSuffix string, service string) StatItem { + return &DefaultStatItem{group: group, groupSuffix: groupSuffix, service: service, holder: &RegistryHolder{registry: metrics.NewRegistry()}, isReport: true} } -func RMStatItem(group string, service string) { +func RMStatItem(group, groupSuffix string, service string) { itemsLock.RLock() - i := items[group+service] + i := items[group+groupSuffix+service] itemsLock.RUnlock() if i != nil { i.Clear() itemsLock.Lock() - delete(items, group+service) + delete(items, group+groupSuffix+service) itemsLock.Unlock() } } @@ -165,14 +173,19 @@ func StatItemSize() int { return len(items) } +// Escape the string avoid invalid graphite key func Escape(s string) string { - return strings.Map(func(char rune) rune { + if v, ok := escapeCache.Load(s); ok { + return v.(string) + } + v := strings.Map(func(char rune) rune { if (char >= 'a' && char <= 'z') || (char >= 'A' && char <= 'Z') || (char >= '0' && char <= '9') || (char == '-') { return char - } else { - return '_' } + return '_' }, s) + escapeCache.Store(s, v) + return v } func AddCounter(group string, service string, key string, value int64) { @@ -187,13 +200,29 @@ func AddGauge(group string, service string, key string, value int64) { sendEvent(eventGauge, group, service, key, value) } +// AddCounterWithKeys arguments: group & groupSuffix & service & keys elements & keySuffix is text without escaped +func AddCounterWithKeys(group, groupSuffix string, service string, keys []string, keySuffix string, value int64) { + sendEventWithKeys(eventCounter, group, groupSuffix, service, keys, keySuffix, value) +} + +// AddHistogramsWithKeys arguments: group & groupSuffix & service & keys elements & keySuffix is text without escaped +func AddHistogramsWithKeys(group, groupSuffix string, service string, keys []string, suffix string, duration int64) { + sendEventWithKeys(eventHistograms, group, groupSuffix, service, keys, suffix, duration) +} + func sendEvent(eventType int32, group string, service string, key string, value int64) { - evt := rp.evtBuf.Get().(*event) + sendEventWithKeys(eventType, group, "", service, []string{key}, "", value) +} + +func sendEventWithKeys(eventType int32, group, groupSuffix string, service string, keys []string, suffix string, value int64) { + evt := rp.eventPool.Get().(*event) evt.event = eventType - evt.key = key + evt.keys = keys evt.group = group evt.service = service evt.value = value + evt.keySuffix = suffix + evt.groupSuffix = groupSuffix select { case rp.eventBus <- evt: default: @@ -239,11 +268,43 @@ func startSampleStatus(application string) { } type event struct { - event int32 - key string - group string - service string - value int64 + event int32 + keys []string + keySuffix string + group string + groupSuffix string + service string + value int64 +} + +// reset used to reset the event object before put it back +// to event objects pool +func (s *event) reset() { + s.event = 0 + s.keys = s.keys[:0] + s.keySuffix = "" + s.group = "" + s.service = "" + s.value = 0 + s.groupSuffix = "" +} + +// getMetricKey get the metrics key when add metrics data into metrics object, +// the key split by : used to when send data to graphite +func (s *event) getMetricKey() string { + keyBuilder := motan.AcquireBytesBuffer(metricsKeyBuilderBufferSize) + defer motan.ReleaseBytesBuffer(keyBuilder) + l := len(s.keys) + for idx, k := range s.keys { + keyBuilder.WriteString(Escape(k)) + if idx < l-1 { + keyBuilder.WriteString(":") + } + } + if s.keySuffix != "" { + keyBuilder.WriteString(s.keySuffix) + } + return string(keyBuilder.Bytes()) } type RegistryHolder struct { @@ -252,6 +313,7 @@ type RegistryHolder struct { type DefaultStatItem struct { group string + groupSuffix string service string holder *RegistryHolder isReport bool @@ -271,6 +333,11 @@ func (d *DefaultStatItem) GetService() string { return d.service } +// GetEscapedService return the escaped service used as graphite key +func (d *DefaultStatItem) GetEscapedService() string { + return Escape(d.service) +} + func (d *DefaultStatItem) SetGroup(group string) { d.group = group } @@ -279,6 +346,15 @@ func (d *DefaultStatItem) GetGroup() string { return d.group } +func (d *DefaultStatItem) GetGroupSuffix() string { + return d.groupSuffix +} + +// GetEscapedGroup return the escaped group used as graphite key +func (d *DefaultStatItem) GetEscapedGroup() string { + return Escape(d.group) +} + func (d *DefaultStatItem) AddCounter(key string, value int64) { c := d.getRegistry().Get(key) if c == nil { @@ -315,6 +391,7 @@ func (d *DefaultStatItem) SnapshotAndClear() Snapshot { old := atomic.SwapPointer((*unsafe.Pointer)(unsafe.Pointer(&d.holder)), unsafe.Pointer(&RegistryHolder{registry: metrics.NewRegistry()})) d.lastSnapshot = &ReadonlyStatItem{ group: d.group, + groupSuffix: d.groupSuffix, service: d.service, holder: (*RegistryHolder)(old), isReport: d.isReport, @@ -324,8 +401,8 @@ func (d *DefaultStatItem) SnapshotAndClear() Snapshot { return d.lastSnapshot } +// SnapshotAndClearV0 Using SnapshotAndClear instead. // Deprecated. -// Using SnapshotAndClear instead. // Because of Snapshot(DefaultStatItem) calculates metrics will call locker to do that, // cause low performance func (d *DefaultStatItem) SnapshotAndClearV0() Snapshot { @@ -467,6 +544,7 @@ func (d *DefaultStatItem) IsGauge(key string) bool { type ReadonlyStatItem struct { group string + groupSuffix string service string holder *RegistryHolder isReport bool @@ -477,6 +555,7 @@ type ReadonlyStatItem struct { func (d *ReadonlyStatItem) getRegistry() metrics.Registry { return (*RegistryHolder)(atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.holder)))).registry } + func (d *ReadonlyStatItem) getCache(key string) interface{} { d.buildCacheLock.RLock() if v, ok := d.cache[key]; ok { @@ -521,6 +600,11 @@ func (d *ReadonlyStatItem) GetService() string { return d.service } +// GetEscapedService return the escaped service used as graphite key +func (d *ReadonlyStatItem) GetEscapedService() string { + return Escape(d.service) +} + func (d *ReadonlyStatItem) SetGroup(group string) { panic("action not supported") } @@ -529,6 +613,15 @@ func (d *ReadonlyStatItem) GetGroup() string { return d.group } +func (d *ReadonlyStatItem) GetGroupSuffix() string { + return d.groupSuffix +} + +// GetEscapedGroup return the escaped group used as graphite key +func (d *ReadonlyStatItem) GetEscapedGroup() string { + return Escape(d.group) +} + func (d *ReadonlyStatItem) AddCounter(key string, value int64) { panic("action not supported") } @@ -725,13 +818,16 @@ type reporter struct { interval time.Duration processor int writers map[string]StatWriter - evtBuf *sync.Pool + eventPool *sync.Pool writersLock sync.RWMutex } func (r *reporter) eventLoop() { for evt := range r.eventBus { r.processEvent(evt) + // clean the event object before put it back + evt.reset() + r.eventPool.Put(evt) } } @@ -746,14 +842,15 @@ func (r *reporter) addWriter(key string, sw StatWriter) { func (r *reporter) processEvent(evt *event) { defer motan.HandlePanic(nil) - item := GetOrRegisterStatItem(evt.group, evt.service) + item := GetOrRegisterStatItem(evt.group, evt.groupSuffix, evt.service) + key := evt.getMetricKey() switch evt.event { case eventCounter: - item.AddCounter(evt.key, evt.value) + item.AddCounter(key, evt.value) case eventHistograms: - item.AddHistograms(evt.key, evt.value) + item.AddHistograms(key, evt.value) case eventGauge: - item.AddGauge(evt.key, evt.value) + item.AddGauge(key, evt.value) } } diff --git a/metrics/metrics_test.go b/metrics/metrics_test.go index 604f4359..612640b6 100644 --- a/metrics/metrics_test.go +++ b/metrics/metrics_test.go @@ -19,19 +19,19 @@ var ( func TestGetStatItem(t *testing.T) { ClearStatItems() // get - si := GetStatItem(group, service) + si := GetStatItem(group, "", service) assert.Nil(t, si, "GetStatItem") assert.Equal(t, 0, StatItemSize(), "item size") // register - si = GetOrRegisterStatItem(group, service) + si = GetOrRegisterStatItem(group, "", service) assert.NotNil(t, si, "GetOrRegisterStatItem") assert.Equal(t, group, si.GetGroup(), "StatItem group") assert.Equal(t, service, si.GetService(), "StatItem service") assert.Equal(t, 1, StatItemSize(), "item size") - again := GetStatItem(group, service) + again := GetStatItem(group, "", service) assert.True(t, si == again, "get StatItem not the same one") - si2 := GetOrRegisterStatItem(group+"2", service+"2") + si2 := GetOrRegisterStatItem(group+"2", "", service+"2") assert.NotNil(t, si2, "GetOrRegisterStatItem") assert.Equal(t, group+"2", si2.GetGroup(), "StatItem group") assert.Equal(t, service+"2", si2.GetService(), "StatItem service") @@ -39,17 +39,17 @@ func TestGetStatItem(t *testing.T) { assert.Equal(t, 2, StatItemSize(), "item size") // rm - RMStatItem(group, service) - si = GetStatItem(group, service) + RMStatItem(group, "", service) + si = GetStatItem(group, "", service) assert.Nil(t, si, "GetStatItem") - si2 = GetStatItem(group+"2", service+"2") + si2 = GetStatItem(group+"2", "", service+"2") assert.NotNil(t, si2, "GetOrRegisterStatItem") // clear ClearStatItems() - si = GetStatItem(group, service) + si = GetStatItem(group, "", service) assert.Nil(t, si, "clear not work") - si2 = GetStatItem(group+"2", service+"2") + si2 = GetStatItem(group+"2", "", service+"2") assert.Nil(t, si2, "clear not work") // multi thread @@ -59,7 +59,7 @@ func TestGetStatItem(t *testing.T) { for i := 0; i < size; i++ { j := i go func() { - s := GetOrRegisterStatItem(group, service) + s := GetOrRegisterStatItem(group, "", service) lock.Lock() sia[j] = s lock.Unlock() @@ -75,14 +75,14 @@ func TestGetStatItem(t *testing.T) { } func TestNewDefaultStatItem(t *testing.T) { - si := NewStatItem(group, service) + si := NewStatItem(group, "", service) assert.NotNil(t, si, "NewStatItem") _, ok := si.(StatItem) assert.True(t, ok, "type not StatItem") } func TestDefaultStatItem(t *testing.T) { - item := NewDefaultStatItem(group, service) + item := NewDefaultStatItem(group, "", service) assert.NotNil(t, item, "GetOrRegisterStatItem") assert.Equal(t, group, item.GetGroup(), "StatItem group") assert.Equal(t, service, item.GetService(), "StatItem service") @@ -161,7 +161,7 @@ func TestStat(t *testing.T) { AddGauge(group, service, "g2", 200) } time.Sleep(100 * time.Millisecond) - item := GetStatItem(group, service) + item := GetStatItem(group, "", service) assert.NotNil(t, item, "item not exist") snap := item.SnapshotAndClear() // test counters @@ -236,7 +236,7 @@ func (m *mockWriter) GetSnapshot() []Snapshot { func TestReadonlyStatItem(t *testing.T) { assert := assert.New(t) - item := NewDefaultStatItem(group, service) + item := NewDefaultStatItem(group, "", service) item.SetService(service) item.SetGroup(group) roItem := item.SnapshotAndClear() @@ -281,7 +281,7 @@ func TestReadonlyStatItem(t *testing.T) { } func TestDefaultStatItem_1(t *testing.T) { - item := NewDefaultStatItem(group, service) + item := NewDefaultStatItem(group, "", service) assert.Equal(t, "str", Escape("str")) assert.NotNil(t, item, "GetOrRegisterStatItem") assert.Equal(t, group, item.GetGroup(), "StatItem group") diff --git a/protocol/motan1Protocol.go b/protocol/motan1Protocol.go index 136c525f..19d9004c 100644 --- a/protocol/motan1Protocol.go +++ b/protocol/motan1Protocol.go @@ -60,6 +60,11 @@ const ( HEARTBEAT_RESPONSE_STRING = HEARTBEAT_METHOD_NAME ) +const ( + V1Group = "group" + V1Version = "version" +) + const MAX_BLOCK_SIZE = 1024 // base binary arrays diff --git a/protocol/motanProtocol.go b/protocol/motanProtocol.go index 0e1e413e..8b102c16 100644 --- a/protocol/motanProtocol.go +++ b/protocol/motanProtocol.go @@ -11,6 +11,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" motan "github.com/weibocom/motan-go/core" @@ -20,9 +21,10 @@ import ( const ( DefaultMetaSize = 16 DefaultMaxContentLength = 10 * 1024 * 1024 + DefaultBufferSize = 256 ) -//message type +// message type const ( Req = iota Res @@ -32,6 +34,7 @@ const ( const ( MotanMagic = 0xf1f1 HeaderLength = 13 + MetaStartOffset = HeaderLength + 4 Version1 = 0 Version2 = 1 defaultProtocol = "motan2" @@ -40,7 +43,7 @@ const ( const ( MPath = "M_p" MMethod = "M_m" - MExceptionn = "M_e" + MException = "M_e" MProcessTime = "M_pt" MMethodDesc = "M_md" MGroup = "M_g" @@ -60,6 +63,14 @@ type Header struct { RequestID uint64 } +func (h *Header) Reset() { + h.Magic = 0 + h.MsgType = 0 + h.VersionStatus = 0 + h.Serialize = 0 + h.RequestID = 0 +} + func (h *Header) Clone() *Header { return &Header{ Magic: h.Magic, @@ -71,13 +82,16 @@ func (h *Header) Clone() *Header { } type Message struct { - Header *Header - Metadata *motan.StringMap - Body []byte - Type int + Header *Header + Metadata *motan.StringMap + Body []byte + Type int + canRelease atomic.Value + // lazy init or reset when call Message.Encode or Message.Encode0 + bytesBuffer *motan.BytesBuffer } -//serialize +// serialize const ( Hessian = iota GrpcPb @@ -104,6 +118,9 @@ var ( writeBufPool = &sync.Pool{New: func() interface{} { // for gzip write buffer return &bytes.Buffer{} }} + messagePool = sync.Pool{New: func() interface{} { + return &Message{Metadata: motan.NewStringMap(DefaultMetaSize), Header: &Header{}} + }} ) // errors @@ -248,16 +265,34 @@ func BuildHeader(msgType int, proxy bool, serialize int, requestID uint64, msgSt } //status - status := uint8(0x08 | (uint8(msgStatus) & 0x07)) + status := 0x08 | (uint8(msgStatus) & 0x07) - serial := uint8(0x00 | (uint8(serialize) << 3)) + serial := 0x00 | (uint8(serialize) << 3) header := &Header{MotanMagic, mtype, status, serial, requestID} return header } -func (msg *Message) Encode() (buf *motan.BytesBuffer) { - metabuf := motan.NewBytesBuffer(256) +// Encode0 encode header、meta、body size +// used with GetEncodedBytes method +// unexpected if call Message.GetEncodedBytes after rewrite result(*motan.BytesBuffer) +func (msg *Message) Encode0() { + if msg.bytesBuffer == nil { + msg.bytesBuffer = motan.NewBytesBuffer(DefaultBufferSize) + } else { + msg.bytesBuffer.Reset() + } + + // encode header. + msg.bytesBuffer.WriteUint16(MotanMagic) + msg.bytesBuffer.WriteByte(msg.Header.MsgType) + msg.bytesBuffer.WriteByte(msg.Header.VersionStatus) + msg.bytesBuffer.WriteByte(msg.Header.Serialize) + msg.bytesBuffer.WriteUint64(msg.Header.RequestID) + + // 4 byte for meta size + msg.bytesBuffer.SetWPos(MetaStartOffset) + // encode meta msg.Metadata.Range(func(k, v string) bool { if k == "" || v == "" { return true @@ -266,38 +301,60 @@ func (msg *Message) Encode() (buf *motan.BytesBuffer) { vlog.Errorf("metadata not correct.k:%s, v:%s", k, v) return true } - metabuf.Write([]byte(k)) - metabuf.WriteByte('\n') - metabuf.Write([]byte(v)) - metabuf.WriteByte('\n') + msg.bytesBuffer.WriteString(k) + msg.bytesBuffer.WriteByte('\n') + msg.bytesBuffer.WriteString(v) + msg.bytesBuffer.WriteByte('\n') return true }) + metaWpos := msg.bytesBuffer.GetWPos() + metaSize := 0 + if metaWpos != MetaStartOffset { + // rewrite meta last char '\n' + metaWpos -= 1 + metaSize = metaWpos - MetaStartOffset - if metabuf.Len() > 0 { - metabuf.SetWPos(metabuf.GetWPos() - 1) } - metasize := metabuf.Len() - bodysize := len(msg.Body) - buf = motan.NewBytesBuffer(int(HeaderLength + bodysize + metasize + 8)) - // encode header. - buf.WriteUint16(MotanMagic) - buf.WriteByte(msg.Header.MsgType) - buf.WriteByte(msg.Header.VersionStatus) - buf.WriteByte(msg.Header.Serialize) - buf.WriteUint64(msg.Header.RequestID) + msg.bytesBuffer.SetWPos(HeaderLength) + msg.bytesBuffer.WriteUint32(uint32(metaSize)) - // encode meta - buf.WriteUint32(uint32(metasize)) - if metasize > 0 { - buf.Write(metabuf.Bytes()) + msg.bytesBuffer.SetWPos(metaWpos) + // encode body size + bodySize := len(msg.Body) + msg.bytesBuffer.WriteUint32(uint32(bodySize)) +} + +// GetEncodedBytes get encoded headers Message.Body +// used with Encode0 method +// unexpected if repeat call Message.GetEncodedBytes after rewrite result([][]byte) +func (msg *Message) GetEncodedBytes() [][]byte { + if msg.bytesBuffer == nil { // maybe not call Message.Encode or Message.Encode0 + msg.Encode0() } + return [][]byte{msg.bytesBuffer.Bytes(), msg.Body} +} + +// Encode result encoded result +// unexpected if call Message.GetEncodedBytes after rewrite result(*motan.BytesBuffer) +func (msg *Message) Encode() *motan.BytesBuffer { + msg.Encode0() + bodySize := len(msg.Body) + if bodySize > 0 { + msg.bytesBuffer.Write(msg.Body) + } + return msg.bytesBuffer +} - // encode body - buf.WriteUint32(uint32(bodysize)) - if bodysize > 0 { - buf.Write(msg.Body) +func (msg *Message) Reset() bool { + if ok, v := msg.canRelease.Load().(bool); ok && v { + msg.Type = 0 + msg.Body = msg.Body[:0] + msg.Header.Reset() + msg.Metadata.Reset() + msg.canRelease.Store(false) + return true } - return buf + return false } func (msg *Message) Clone() interface{} { @@ -312,6 +369,10 @@ func (msg *Message) Clone() interface{} { return newMessage } +func (msg *Message) SetCanRelease() { + msg.canRelease.Store(true) +} + func CheckMotanVersion(buf *bufio.Reader) (version int, err error) { var b []byte b, err = buf.Peek(4) @@ -326,93 +387,111 @@ func CheckMotanVersion(buf *bufio.Reader) (version int, err error) { return int(b[3] >> 3 & 0x1f), nil } -func Decode(buf *bufio.Reader) (msg *Message, err error) { - msg, _, err = DecodeWithTime(buf, motan.DefaultMaxContentLength) +func Decode(reader *bufio.Reader, buf *[]byte) (msg *Message, err error) { + msg, _, err = DecodeWithTime(reader, buf, motan.DefaultMaxContentLength) return msg, err } -func DecodeWithTime(buf *bufio.Reader, maxContentLength int) (msg *Message, start time.Time, err error) { - temp := make([]byte, HeaderLength, HeaderLength) - +// DecodeWithTime the parameter buf is a slice pointer, so the +// extension of *buf will affect the slice outside in order to +// reuse the buf which is used by reading header info from reader +func DecodeWithTime(reader *bufio.Reader, buf *[]byte, maxContentLength int) (msg *Message, start time.Time, err error) { + start = time.Now() // record time when starting to reader data + readSlice := *buf // decode header - _, err = io.ReadAtLeast(buf, temp, HeaderLength) - start = time.Now() // record time when starting to read data + _, err = io.ReadAtLeast(reader, readSlice[:HeaderLength], HeaderLength) if err != nil { return nil, start, err } - mn := binary.BigEndian.Uint16(temp[:2]) // TODO 不再验证 + mn := binary.BigEndian.Uint16(readSlice[:2]) // TODO 不再验证 if mn != MotanMagic { vlog.Errorf("wrong magic num:%d, err:%v", mn, err) return nil, start, ErrMagicNum } - - header := &Header{Magic: MotanMagic} - header.MsgType = temp[2] - header.VersionStatus = temp[3] - version := header.GetVersion() + // get a message from pool + msg = AcquireMessage() + defer func() { + if err != nil { + msg.SetCanRelease() + ReleaseMessage(msg) + } + }() + msg.Header.Magic = MotanMagic + msg.Header.MsgType = readSlice[2] + msg.Header.VersionStatus = readSlice[3] + version := msg.Header.GetVersion() if version != Version2 { // TODO 不再验证 + err = ErrVersion vlog.Errorf("unsupported protocol version number: %d", version) return nil, start, ErrVersion } - header.Serialize = temp[4] - header.RequestID = binary.BigEndian.Uint64(temp[5:]) + msg.Header.Serialize = readSlice[4] + msg.Header.RequestID = binary.BigEndian.Uint64(readSlice[5:HeaderLength]) // decode meta - _, err = io.ReadAtLeast(buf, temp[:4], 4) + _, err = io.ReadAtLeast(reader, readSlice[:4], 4) if err != nil { return nil, start, err } - metasize := int(binary.BigEndian.Uint32(temp[:4])) + metasize := int(binary.BigEndian.Uint32(readSlice[:4])) if metasize > maxContentLength { + err = ErrOverSize vlog.Errorf("meta over size. meta size:%d, max size:%d", metasize, maxContentLength) return nil, start, ErrOverSize } - metamap := motan.NewStringMap(DefaultMetaSize) if metasize > 0 { - metadata, err := readBytes(buf, metasize) + if cap(readSlice) < metasize { + readSlice = make([]byte, metasize) + *buf = readSlice + } + err = readBytes(reader, readSlice, metasize) if err != nil { return nil, start, err } s, e := 0, 0 var k string for i := 0; i <= metasize; i++ { - if i == metasize || metadata[i] == '\n' { + if i == metasize || readSlice[i] == '\n' { e = i if k == "" { - k = string(metadata[s:e]) + k = string(readSlice[s:e]) } else { - metamap.Store(k, string(metadata[s:e])) + msg.Metadata.Store(k, string(readSlice[s:e])) k = "" } s = i + 1 } } if k != "" { - vlog.Errorf("decode message fail, metadata not paired. header:%v, meta:%s", header, metadata) + err = ErrMetadata + vlog.Errorf("decode message fail, metadata not paired. header:%v, meta:%s", msg.Header, readSlice) return nil, start, ErrMetadata } } //decode body - _, err = io.ReadAtLeast(buf, temp[:4], 4) + _, err = io.ReadAtLeast(reader, readSlice[:4], 4) if err != nil { return nil, start, err } - bodysize := int(binary.BigEndian.Uint32(temp[:4])) - if bodysize > maxContentLength { - vlog.Errorf("body over size. body size:%d, max size:%d", bodysize, maxContentLength) + bodySize := int(binary.BigEndian.Uint32(readSlice[:4])) + if bodySize > maxContentLength { + err = ErrOverSize + vlog.Errorf("body over size. body size:%d, max size:%d", bodySize, maxContentLength) return nil, start, ErrOverSize } - var body []byte - if bodysize > 0 { - body, err = readBytes(buf, bodysize) + if bodySize > 0 { + if cap(msg.Body) < bodySize { + msg.Body = make([]byte, bodySize) + } + msg.Body = msg.Body[:bodySize] + err = readBytes(reader, msg.Body, bodySize) } else { - body = make([]byte, 0) + msg.Body = make([]byte, 0) } if err != nil { return nil, start, err } - msg = &Message{header, metamap, body, Req} return msg, start, err } @@ -425,15 +504,14 @@ func DecodeGzipBody(body []byte) []byte { return ret } -func readBytes(buf *bufio.Reader, size int) ([]byte, error) { - tempbytes := make([]byte, size) +func readBytes(buf *bufio.Reader, readSlice []byte, size int) error { var s, n = 0, 0 var err error for s < size && err == nil { - n, err = buf.Read(tempbytes[s:]) + n, err = buf.Read(readSlice[s:size]) s += n } - return tempbytes, err + return err } func EncodeGzip(data []byte) ([]byte, error) { @@ -523,7 +601,7 @@ func DecodeGzip(data []byte) (ret []byte, err error) { // ConvertToRequest convert motan2 protocol request message to motan Request func ConvertToRequest(request *Message, serialize motan.Serialization) (motan.Request, error) { - motanRequest := &motan.MotanRequest{Arguments: make([]interface{}, 0)} + motanRequest := motan.AcquireMotanRequest() motanRequest.RequestID = request.Header.RequestID if idStr, ok := request.Metadata.Load(MRequestID); !ok { if request.Header.IsProxy() { @@ -549,12 +627,11 @@ func ConvertToRequest(request *Message, serialize motan.Serialization) (motan.Re request.Header.SetGzip(false) } if !rc.Proxy && serialize == nil { + motan.ReleaseMotanRequest(motanRequest) return nil, ErrSerializeNil } - dv := &motan.DeserializableValue{Body: request.Body, Serialization: serialize} - motanRequest.Arguments = []interface{}{dv} + motanRequest.Arguments = append(motanRequest.Arguments, &motan.DeserializableValue{Body: request.Body, Serialization: serialize}) } - return motanRequest, nil } @@ -633,12 +710,11 @@ func ConvertToResMessage(response motan.Response, serialize motan.Serialization) return msg, nil } } - - res := &Message{} + res := AcquireMessage() var msgType int if response.GetException() != nil { msgType = Exception - response.SetAttachment(MExceptionn, ExceptionToJSON(response.GetException())) + response.SetAttachment(MException, ExceptionToJSON(response.GetException())) if rc.Proxy { rc.Serialized = true } @@ -660,11 +736,15 @@ func ConvertToResMessage(response motan.Response, serialize motan.Serialization) res.Body = b } else { vlog.Warningf("convert response value fail! serialized value not []byte. res:%+v", response) + res.SetCanRelease() + ReleaseMessage(res) return nil, ErrSerializedData } } else { b, err := serialize.Serialize(response.GetValue()) if err != nil { + res.SetCanRelease() + ReleaseMessage(res) return nil, err } res.Body = b @@ -683,7 +763,7 @@ func ConvertToResMessage(response motan.Response, serialize motan.Serialization) // ConvertToResponse convert protocol response to motan Response func ConvertToResponse(response *Message, serialize motan.Serialization) (motan.Response, error) { - mres := &motan.MotanResponse{} + mres := motan.AcquireMotanResponse() rc := mres.GetRPCContext(true) rc.Proxy = response.Header.IsProxy() mres.RequestID = response.Header.RequestID @@ -699,17 +779,19 @@ func ConvertToResponse(response *Message, serialize motan.Serialization) (motan. response.Header.SetGzip(false) } if !rc.Proxy && serialize == nil { + motan.ReleaseMotanResponse(mres) return nil, ErrSerializeNil } dv := &motan.DeserializableValue{Body: response.Body, Serialization: serialize} mres.Value = dv } if response.Header.GetStatus() == Exception { - e := response.Metadata.LoadOrEmpty(MExceptionn) + e := response.Metadata.LoadOrEmpty(MException) if e != "" { var exception *motan.Exception err := json.Unmarshal([]byte(e), &exception) if err != nil { + motan.ReleaseMotanResponse(mres) return nil, err } mres.Exception = exception @@ -722,13 +804,29 @@ func ConvertToResponse(response *Message, serialize motan.Serialization) (motan. } func BuildExceptionResponse(requestID uint64, errmsg string) *Message { - header := BuildHeader(Res, false, defaultSerialize, requestID, Exception) - msg := &Message{Header: header, Metadata: motan.NewStringMap(DefaultMetaSize)} - msg.Metadata.Store(MExceptionn, errmsg) - return msg + message := AcquireMessage() + message.Header.RequestID = requestID + message.Header.SetRequest(false) + message.Header.SetProxy(false) + message.Header.SetVersion(Version2) + message.Header.SetStatus(Exception) + message.Header.SetSerialize(defaultSerialize) + message.Metadata.Store(MException, errmsg) + + return message } func ExceptionToJSON(e *motan.Exception) string { errmsg, _ := json.Marshal(e) return string(errmsg) } + +func AcquireMessage() *Message { + return messagePool.Get().(*Message) +} + +func ReleaseMessage(msg *Message) { + if msg != nil && msg.Reset() { + messagePool.Put(msg) + } +} diff --git a/protocol/motanProtocol_test.go b/protocol/motanProtocol_test.go index 060ae0cb..81563f16 100644 --- a/protocol/motanProtocol_test.go +++ b/protocol/motanProtocol_test.go @@ -5,8 +5,10 @@ import ( "bytes" "compress/gzip" "fmt" + "github.com/stretchr/testify/assert" "github.com/weibocom/motan-go/serialize" "math/rand" + "strconv" "sync" "sync/atomic" "testing" @@ -145,7 +147,8 @@ func TestEncode(t *testing.T) { ebytes := msg.Encode() fmt.Println("len:", ebytes.Len()) - newMsg, err := Decode(bufio.NewReader(ebytes)) + readSlice := make([]byte, 100) + newMsg, err := Decode(bufio.NewReader(ebytes), &readSlice) if newMsg == nil { t.Fatalf("encode message fail") } @@ -164,12 +167,11 @@ func TestEncode(t *testing.T) { msg.Header.SetGzip(true) msg.Body, _ = EncodeGzip([]byte("gzip encode")) b := msg.Encode() - newMsg, _ = Decode(bufio.NewReader(b)) + newMsg, _ = Decode(bufio.NewReader(b), &readSlice) // should not decode gzip if !newMsg.Header.IsGzip() { t.Fatalf("encode message fail") } - nb, err := DecodeGzip(newMsg.Body) if err != nil { t.Errorf("decode gzip fail. err:%v", err) @@ -177,20 +179,127 @@ func TestEncode(t *testing.T) { assertTrue(string(nb) == "gzip encode", "body", t) } +func TestMessage_GetEncodedBytes(t *testing.T) { + h := &Header{} + h.SetVersion(Version2) + h.SetStatus(6) + h.SetOneWay(true) + h.SetSerialize(5) + h.SetGzip(true) + h.SetHeartbeat(true) + h.SetProxy(true) + h.SetRequest(true) + h.Magic = MotanMagic + h.RequestID = 2349789 + meta := core.NewStringMap(0) + meta.Store("k1", "v1") + body := []byte("testbody") + msg := &Message{Header: h, Metadata: meta, Body: body} + msg.Encode0() + encodedBytes := msg.GetEncodedBytes() + buf := core.CreateBytesBuffer(encodedBytes[0]) + + assert.Equal(t, "testbody", string(encodedBytes[1])) + + // append body to buf + buf.Write(msg.Body) + // verify decode + readSlice := make([]byte, 100) + newMsg, err := Decode(bufio.NewReader(buf), &readSlice) + if err != nil || newMsg == nil { + t.Fatalf("encode message fail") + } + + // verify header + assertTrue(newMsg.Header.IsOneWay(), "oneway", t) + assertTrue(newMsg.Header.IsGzip(), "gzip", t) + assertTrue(newMsg.Header.IsHeartbeat(), "heartbeat", t) + assertTrue(newMsg.Header.IsProxy(), "proxy", t) + assertTrue(newMsg.Header.isRequest(), "request", t) + assertTrue(newMsg.Header.GetVersion() == Version2, "version", t) + assertTrue(newMsg.Header.GetSerialize() == 5, "serialize", t) + assertTrue(newMsg.Header.GetStatus() == 6, "status", t) + + // verify meta + assertTrue(newMsg.Metadata.LoadOrEmpty("k1") == "v1", "meta", t) + + // verify body nil + assertTrue(len(newMsg.Body) == len(msg.Body), "body", t) +} + +func TestPool(t *testing.T) { + h := &Header{} + h.SetVersion(Version2) + h.SetStatus(6) + h.SetOneWay(true) + h.SetSerialize(5) + h.SetGzip(true) + h.SetHeartbeat(true) + h.SetProxy(true) + h.SetRequest(true) + h.Magic = MotanMagic + h.RequestID = 2349789 + meta := core.NewStringMap(0) + for mi := 0; mi < 10000; mi++ { + meta.Store(strconv.Itoa(mi), strconv.Itoa(mi)) + } + body := []byte("testbodytestbodytestbodytestbodytestbodytestbodytestbodytestbodytestbodytestbodytestbody") + msg := &Message{Header: h, Metadata: meta, Body: body} + ebytes := msg.Encode() + + fmt.Println("len:", ebytes.Len()) + readSlice := make([]byte, 100) + newMsg, err := Decode(bufio.NewReader(ebytes), &readSlice) + if newMsg == nil { + t.Fatalf("encode message fail") + } + assertTrue(newMsg.Header.IsOneWay(), "oneway", t) + assertTrue(newMsg.Header.IsGzip(), "gzip", t) + assertTrue(newMsg.Header.IsHeartbeat(), "heartbeat", t) + assertTrue(newMsg.Header.IsProxy(), "proxy", t) + assertTrue(newMsg.Header.isRequest(), "request", t) + assertTrue(newMsg.Header.GetVersion() == Version2, "version", t) + assertTrue(newMsg.Header.GetSerialize() == 5, "serialize", t) + assertTrue(newMsg.Header.GetStatus() == 6, "status", t) + assertTrue(newMsg.Metadata.LoadOrEmpty("1") == "1", "meta", t) + assertTrue(cap(readSlice) > 200, "readSlice", t) + assertTrue(len(newMsg.Body) == len(msg.Body), "body", t) + assert.Nil(t, err) + ReleaseMessage(newMsg) + body1 := []byte("testbody") + msg1 := &Message{Header: h, Metadata: meta, Body: body1} + ebytes1 := msg1.Encode() + newMsg, err = Decode(bufio.NewReader(ebytes1), &readSlice) + if newMsg == nil { + t.Fatalf("encode message fail") + } + assertTrue(newMsg.Header.IsOneWay(), "oneway", t) + assertTrue(newMsg.Header.IsGzip(), "gzip", t) + assertTrue(newMsg.Header.IsHeartbeat(), "heartbeat", t) + assertTrue(newMsg.Header.IsProxy(), "proxy", t) + assertTrue(newMsg.Header.isRequest(), "request", t) + assertTrue(newMsg.Header.GetVersion() == Version2, "version", t) + assertTrue(newMsg.Header.GetSerialize() == 5, "serialize", t) + assertTrue(newMsg.Header.GetStatus() == 6, "status", t) + assertTrue(newMsg.Metadata.LoadOrEmpty("1") == "1", "meta", t) + assertTrue(cap(readSlice) > 200, "readSlice", t) + assertTrue(len(newMsg.Body) == len(msg1.Body), "body", t) +} + func assertTrue(b bool, msg string, t *testing.T) { if !b { t.Fatalf("test fail, %s not correct.", msg) } } -//TODO convert -func TestConvertToRequest(t *testing.T) { +func TestConvertToResponse(t *testing.T) { h := &Header{} h.SetVersion(Version2) h.SetStatus(6) h.SetOneWay(true) h.SetSerialize(6) - h.SetGzip(true) + h.SetStatus(0) + h.SetGzip(false) h.SetHeartbeat(true) h.SetProxy(true) h.SetRequest(true) @@ -203,13 +312,160 @@ func TestConvertToRequest(t *testing.T) { meta.Store(MPath, "path") body := []byte("testbody") msg := &Message{Header: h, Metadata: meta, Body: body} - req, err := ConvertToRequest(msg, &serialize.SimpleSerialization{}) - assertTrue(err == nil, "conver to request err", t) - assertTrue(req.GetAttachment(MGroup) == "group", "request group", t) - assertTrue(req.GetAttachment(MMethod) == "method", "request method", t) - assertTrue(req.GetAttachment(MPath) == "path", "request path", t) + // To test convert when the method use pool + pMap := make(map[string]string) + for i := 0; i < 10000; i++ { + resp, err := ConvertToResponse(msg, &serialize.SimpleSerialization{}) + assertTrue(err == nil, "conver to request err", t) + assertTrue(resp.GetAttachment(MGroup) == "group", "response group", t) + assertTrue(resp.GetAttachment(MMethod) == "method", "response method", t) + assertTrue(resp.GetAttachment(MPath) == "path", "response path", t) + assertTrue(string(resp.GetValue().(*core.DeserializableValue).Body) == "testbody", "response body", t) + pMap[fmt.Sprintf("%p", resp)] = "1" + core.ReleaseMotanResponse(resp.(*core.MotanResponse)) + } + // check if responses are reused + assert.True(t, len(pMap) < 10000) + // To test if convert is correct when the method is called in concurrent situation + h1 := &Header{} + h1.SetVersion(Version2) + h1.SetStatus(1) + h1.SetOneWay(true) + h1.SetSerialize(6) + h1.SetGzip(false) + h1.SetHeartbeat(true) + h1.SetProxy(true) + h1.SetRequest(true) + h1.Magic = MotanMagic + h1.RequestID = 1234456 + meta1 := core.NewStringMap(0) + meta1.Store("k2", "v2") + meta1.Store(MGroup, "group1") + meta1.Store(MMethod, "method1") + meta1.Store(MPath, "path1") + meta1.Store(MException, `{"errcode": 0, "errmsg": "test exception", "errtype": 1}`) + msg1 := &Message{Header: h1, Metadata: meta1, Body: nil} + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + time.Sleep(time.Millisecond * 500) + for i := 0; i < 10000; i++ { + resp, err := ConvertToResponse(msg, &serialize.SimpleSerialization{}) + assertTrue(resp.GetRequestID() == 2349789, "request id is incorrect", t) + assertTrue(err == nil, "convert to response err", t) + assertTrue(resp.GetAttachment(MGroup) == "group", "response group", t) + assertTrue(resp.GetAttachment(MMethod) == "method", "response method", t) + assertTrue(resp.GetAttachment(MPath) == "path", "response path", t) + assertTrue(string(resp.GetValue().(*core.DeserializableValue).Body) == "testbody", "response body", t) + core.ReleaseMotanResponse(resp.(*core.MotanResponse)) + } + wg.Done() + }() + go func() { + time.Sleep(time.Millisecond * 500) + for i := 0; i < 10000; i++ { + resp, err := ConvertToResponse(msg1, &serialize.SimpleSerialization{}) + assertTrue(resp.GetRequestID() == 1234456, "request id is incorrect", t) + assertTrue(err == nil, "convert to response err", t) + assertTrue(resp.GetAttachment(MGroup) == "group1", "response group", t) + assertTrue(resp.GetAttachment(MMethod) == "method1", "response method", t) + assertTrue(resp.GetAttachment(MPath) == "path1", "response path", t) + assertTrue(resp.GetValue() == nil, "response body", t) + assertTrue(resp.GetException().ErrMsg == "test exception", "response exception error message", t) + assertTrue(resp.GetException().ErrCode == 0, "response exception error code", t) + assertTrue(resp.GetException().ErrType == 1, "response exception error type", t) + core.ReleaseMotanResponse(resp.(*core.MotanResponse)) + } + wg.Done() + }() + wg.Wait() +} +func TestConvertToRequest(t *testing.T) { + h := &Header{} + h.SetVersion(Version2) + h.SetStatus(6) + h.SetOneWay(true) + h.SetSerialize(6) + h.SetGzip(false) + h.SetHeartbeat(true) + h.SetProxy(true) + h.SetRequest(true) + h.Magic = MotanMagic + h.RequestID = 2349789 + meta := core.NewStringMap(0) + meta.Store("k1", "v1") + meta.Store(MGroup, "group") + meta.Store(MMethod, "method") + meta.Store(MPath, "path") + body := []byte("testbody") + msg := &Message{Header: h, Metadata: meta, Body: body} + pMap := make(map[string]string) + // To test convert when the method use pool + for i := 0; i < 10000; i++ { + req, err := ConvertToRequest(msg, &serialize.SimpleSerialization{}) + assertTrue(req.GetRequestID() == 2349789, "request id", t) + assertTrue(err == nil, "conver to request err", t) + assertTrue(req.GetAttachment(MGroup) == "group", "request group", t) + assertTrue(req.GetAttachment(MMethod) == "method", "request method", t) + assertTrue(req.GetAttachment(MPath) == "path", "request path", t) + assertTrue(len(req.GetArguments()) == 1, "request argument", t) + pMap[fmt.Sprintf("%p", req)] = "1" + core.ReleaseMotanRequest(req.(*core.MotanRequest)) + } + // check if requests are reused + assert.True(t, len(pMap) < 10000) + // To test if convert is correct when the method is called in concurrent situation + h1 := &Header{} + h1.SetVersion(Version2) + h1.SetStatus(6) + h1.SetOneWay(true) + h1.SetSerialize(6) + h1.SetGzip(false) + h1.SetHeartbeat(true) + h1.SetProxy(true) + h1.SetRequest(true) + h1.Magic = MotanMagic + h1.RequestID = 1234456 + meta1 := core.NewStringMap(0) + meta1.Store("k2", "v2") + meta1.Store(MGroup, "group1") + meta1.Store(MMethod, "method1") + meta1.Store(MPath, "path1") + msg1 := &Message{Header: h1, Metadata: meta1, Body: nil} + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + time.Sleep(time.Millisecond * 500) + for i := 0; i < 10000; i++ { + req, err := ConvertToRequest(msg, &serialize.SimpleSerialization{}) + assertTrue(req.GetRequestID() == 2349789, "request id is incorrect", t) + assertTrue(err == nil, "conver to request err", t) + assertTrue(req.GetAttachment(MGroup) == "group", "request group", t) + assertTrue(req.GetAttachment(MMethod) == "method", "request method", t) + assertTrue(req.GetAttachment(MPath) == "path", "request path", t) + assertTrue(len(req.GetArguments()) == 1, "request argument", t) + core.ReleaseMotanRequest(req.(*core.MotanRequest)) + } + wg.Done() + }() + go func() { + time.Sleep(time.Millisecond * 500) + for i := 0; i < 10000; i++ { + req, err := ConvertToRequest(msg1, &serialize.SimpleSerialization{}) + assertTrue(req.GetRequestID() == 1234456, "request id is incorrect", t) + assertTrue(err == nil, "conver to request err", t) + assertTrue(req.GetAttachment(MGroup) == "group1", "request group", t) + assertTrue(req.GetAttachment(MMethod) == "method1", "request method", t) + assertTrue(req.GetAttachment(MPath) == "path1", "request path", t) + assertTrue(len(req.GetArguments()) == 0, "request argument", t) + core.ReleaseMotanRequest(req.(*core.MotanRequest)) + } + wg.Done() + }() + wg.Wait() // test request clone + req, err := ConvertToRequest(msg, &serialize.SimpleSerialization{}) cloneReq := req.Clone().(core.Request) assertTrue(err == nil, "conver to request err", t) assertTrue(cloneReq.GetAttachment(MGroup) == "group", "clone request group", t) @@ -337,6 +593,35 @@ func TestConcurrentGzip(t *testing.T) { fmt.Printf("count:%v, errCount: %v\n", count, errCount) } +func TestBuildExceptionResponse(t *testing.T) { + // BuildExceptionResponse + var requestId uint64 = 1234 + err := fmt.Errorf("test error") + exception := &core.Exception{ErrCode: 500, ErrMsg: err.Error(), ErrType: core.ServiceException} + msg := ExceptionToJSON(exception) + + // verify exception message + message := BuildExceptionResponse(requestId, msg) + assert.Equal(t, requestId, message.Header.RequestID) + assert.Equal(t, false, message.Header.isRequest()) + assert.Equal(t, Res, int(message.Header.MsgType)) + assert.Equal(t, false, message.Header.IsProxy()) + assert.Equal(t, Exception, message.Header.GetStatus()) + assert.Equal(t, msg, message.Metadata.LoadOrEmpty(MException)) + + buf := message.Encode() + readSlice := make([]byte, 100) + newMessage, err := Decode(bufio.NewReader(buf), &readSlice) + + // verify encode and decode exception message + assert.Equal(t, message.Header.RequestID, newMessage.Header.RequestID) + assert.Equal(t, message.Header.IsProxy(), newMessage.Header.IsProxy()) + assert.Equal(t, Res, int(message.Header.MsgType)) + assert.Equal(t, message.Header.isRequest(), newMessage.Header.isRequest()) + assert.Equal(t, Exception, newMessage.Header.GetStatus()) + assert.Equal(t, msg, message.Metadata.LoadOrEmpty(MException)) +} + func buildBytes(size int) []byte { baseBytes := []byte("0123456789abcdefghijklmnopqrstuvwxyz") result := bytes.NewBuffer(make([]byte, 0, size)) diff --git a/provider/httpProvider.go b/provider/httpProvider.go index f19cfe69..a5b47ec7 100644 --- a/provider/httpProvider.go +++ b/provider/httpProvider.go @@ -140,6 +140,23 @@ func (h *HTTPProvider) SetContext(context *motan.Context) { h.gctx = context } +// rewrite do rewrite +func (h *HTTPProvider) rewrite(httpReq *fasthttp.Request, request motan.Request) (string, error) { + if h.enableRewrite { + var query []byte + // init query string bytes if needed. + if h.locationMatcher.NeedURLQueryString() { + query = httpReq.URI().QueryString() + } + _, path, ok := h.locationMatcher.Pick(request.GetMethod(), query, true) + if !ok { + return "", errors.New("service not found") + } + return path, nil + } + return request.GetMethod(), nil +} + func buildReqURL(request motan.Request, h *HTTPProvider) (string, string, error) { method := request.GetMethod() httpReqURLFmt := h.url.Parameters["URL_FORMAT"] @@ -219,162 +236,135 @@ func buildQueryStr(request motan.Request, url *motan.URL, mixVars []string) (res return res, err } -// Call for do a motan call through this provider -func (h *HTTPProvider) Call(request motan.Request) motan.Response { - t := time.Now().UnixNano() - resp := &motan.MotanResponse{Attachment: motan.NewStringMap(motan.DefaultAttachmentSize)} +func (h *HTTPProvider) DoTransparentProxy(request motan.Request, t int64, ip string) motan.Response { + resp := mhttp.AcquireHttpMotanResponse() + resp.RequestID = request.GetRequestID() var headerBytes []byte var bodyBytes []byte - doTransparentProxy, _ := strconv.ParseBool(request.GetAttachment(mhttp.Proxy)) - var toType []interface{} - if doTransparentProxy { - // Header and body with []byte - toType = []interface{}{&headerBytes, &bodyBytes} - } else if h.proxyAddr != "" { - toType = nil - } else { - toType = make([]interface{}, 1) - } + toType := []interface{}{&headerBytes, &bodyBytes} if err := request.ProcessDeserializable(toType); err != nil { - fillExceptionWithCode(resp, http.StatusBadRequest, t, err) + fillHttpException(resp, http.StatusBadRequest, t, err.Error()) return resp } - resp.RequestID = request.GetRequestID() - ip := "" - if remoteIP, exist := request.GetAttachments().Load(motan.RemoteIPKey); exist { - ip = remoteIP - } else { - ip = request.GetAttachment(motan.HostKey) + // acquires new fasthttp Request and Response object + httpReq := fasthttp.AcquireRequest() + httpRes := fasthttp.AcquireResponse() + // only release fast http request. The response will be released when Response is released + defer fasthttp.ReleaseRequest(httpReq) + // read http header into Request + httpReq.Header.Read(bufio.NewReader(bytes.NewReader(headerBytes))) + //do rewrite + rewritePath := request.GetMethod() + var err error + rewritePath, err = h.rewrite(httpReq, request) + if err != nil { + fillHttpException(resp, http.StatusNotFound, t, err.Error()) + return resp } - // Ok here we do transparent http proxy and return - if doTransparentProxy { - // acquires new fasthttp Request and Response object - httpReq := fasthttp.AcquireRequest() - httpRes := fasthttp.AcquireResponse() - defer fasthttp.ReleaseRequest(httpReq) - defer fasthttp.ReleaseResponse(httpRes) - // read http header into Request - httpReq.Header.Read(bufio.NewReader(bytes.NewReader(headerBytes))) - - //do rewrite - rewritePath := request.GetMethod() - if h.enableRewrite { - // Do not check upstream for compatibility - - var query []byte - // init query string bytes if needed. - if h.locationMatcher.NeedURLQueryString() { - query = httpReq.URI().QueryString() - } - _, path, ok := h.locationMatcher.Pick(request.GetMethod(), query, true) - if !ok { - fillExceptionWithCode(resp, http.StatusNotFound, t, errors.New("service not found")) - return resp - } - rewritePath = path - } - // sets rewrite - httpReq.URI().SetScheme(h.proxySchema) - httpReq.URI().SetPath(rewritePath) - request.GetAttachments().Range(func(k, v string) bool { - if strings.HasPrefix(k, "M_") { - httpReq.Header.Add(strings.Replace(k, "M_", "MOTAN-", -1), v) - } - return true - }) - httpReq.Header.Del("Connection") - if httpReq.Header.Peek(motan.XForwardedFor) == nil { - httpReq.Header.Set(motan.XForwardedFor, ip) - } - if len(bodyBytes) != 0 { - httpReq.BodyWriter().Write(bodyBytes) - } - err := h.fastClient.Do(httpReq, httpRes) - if err != nil { - fillExceptionWithCode(resp, http.StatusServiceUnavailable, t, err) - return resp - } - if h.enableHttpException { - if httpRes.StatusCode() >= 400 { - fillHttpException(resp, httpRes.StatusCode(), t, httpRes.Body()) - return resp - } + // sets rewrite + httpReq.URI().SetScheme(h.proxySchema) + httpReq.URI().SetPath(rewritePath) + request.GetAttachments().Range(func(k, v string) bool { + if kk, ok := mhttp.InnerAttachmentsConvertMap[k]; ok { + httpReq.Header.Set(kk, v) } - headerBuffer := &bytes.Buffer{} - httpRes.Header.Del("Connection") - httpRes.Header.WriteTo(headerBuffer) - body := httpRes.Body() - resp.ProcessTime = (time.Now().UnixNano() - t) / 1e6 - // copy response body is needed - responseBodyBytes := make([]byte, len(body)) - copy(responseBodyBytes, body) - resp.Value = []interface{}{headerBuffer.Bytes(), responseBodyBytes} - updateUpstreamStatusCode(resp, httpRes.StatusCode()) + return true + }) + httpReq.Header.Del("Connection") + if httpReq.Header.Peek(motan.XForwardedFor) == nil { + httpReq.Header.Set(motan.XForwardedFor, ip) + } + if len(bodyBytes) != 0 { + httpReq.BodyWriter().Write(bodyBytes) + } + err = h.fastClient.Do(httpReq, httpRes) + if err != nil { + fillHttpException(resp, http.StatusServiceUnavailable, t, err.Error()) return resp } + if h.enableHttpException && httpRes.StatusCode() >= 400 { + fillHttpException(resp, httpRes.StatusCode(), t, string(httpRes.Body())) + return resp + } + headerBuffer := &bytes.Buffer{} + httpRes.Header.Del("Connection") + httpRes.Header.WriteTo(headerBuffer) + body := httpRes.Body() + resp.ProcessTime = (time.Now().UnixNano() - t) / 1e6 + // record the response and release later + resp.HttpResponse = httpRes + resp.Value = []interface{}{headerBuffer.Bytes(), body} + updateUpstreamStatusCode(resp, httpRes.StatusCode()) + return resp +} - if h.proxyAddr != "" { - // rpc client call to this server - - // acquires new fasthttp Request and Response object - httpReq := fasthttp.AcquireRequest() - httpRes := fasthttp.AcquireResponse() - defer fasthttp.ReleaseRequest(httpReq) - defer fasthttp.ReleaseResponse(httpRes) - // convert motan request to fasthttp request - err := mhttp.MotanRequestToFasthttpRequest(request, httpReq, h.defaultHTTPMethod) - if err != nil { - fillExceptionWithCode(resp, http.StatusBadRequest, t, err) - return resp - } - rewritePath := request.GetMethod() - if h.enableRewrite { - var query []byte - // init query string bytes if needed. - if h.locationMatcher.NeedURLQueryString() { - query = httpReq.URI().QueryString() - } - _, path, ok := h.locationMatcher.Pick(request.GetMethod(), query, true) - if !ok { - fillExceptionWithCode(resp, http.StatusNotFound, t, errors.New("service not found")) - return resp - } - rewritePath = path - } - - httpReq.URI().SetScheme(h.proxySchema) - httpReq.URI().SetPath(rewritePath) - if len(httpReq.Header.Host()) == 0 { - httpReq.Header.SetHost(h.domain) - } - if httpReq.Header.Peek(motan.XForwardedFor) == nil { - httpReq.Header.Set(motan.XForwardedFor, ip) - } - err = h.fastClient.Do(httpReq, httpRes) - if err != nil { - fillExceptionWithCode(resp, http.StatusServiceUnavailable, t, err) - return resp - } - if h.enableHttpException { - if httpRes.StatusCode() >= 400 { - fillHttpException(resp, httpRes.StatusCode(), t, httpRes.Body()) - return resp - } - } - mhttp.FasthttpResponseToMotanResponse(resp, httpRes) - resp.ProcessTime = (time.Now().UnixNano() - t) / 1e6 - updateUpstreamStatusCode(resp, httpRes.StatusCode()) +// DoProxy deal with Request start from a rpc client +func (h *HTTPProvider) DoProxy(request motan.Request, t int64, ip string) motan.Response { + resp := mhttp.AcquireHttpMotanResponse() + resp.RequestID = request.GetRequestID() + if err := request.ProcessDeserializable(nil); err != nil { + fillHttpException(resp, http.StatusBadRequest, t, err.Error()) + return resp + } + // rpc client call to this server + + // acquires new fasthttp Request and Response object + httpReq := fasthttp.AcquireRequest() + httpRes := fasthttp.AcquireResponse() + // do not release http response + defer fasthttp.ReleaseRequest(httpReq) + // convert motan request to fasthttp request + err := mhttp.MotanRequestToFasthttpRequest(request, httpReq, h.defaultHTTPMethod) + if err != nil { + fillHttpException(resp, http.StatusBadRequest, t, err.Error()) + return resp + } + rewritePath := request.GetMethod() + rewritePath, err = h.rewrite(httpReq, request) + if err != nil { + fillHttpException(resp, http.StatusNotFound, t, err.Error()) + return resp + } + httpReq.URI().SetScheme(h.proxySchema) + httpReq.URI().SetPath(rewritePath) + if len(httpReq.Header.Host()) == 0 { + httpReq.Header.SetHost(h.domain) + } + if httpReq.Header.Peek(motan.XForwardedFor) == nil { + httpReq.Header.Set(motan.XForwardedFor, ip) + } + err = h.fastClient.Do(httpReq, httpRes) + if err != nil { + fillHttpException(resp, http.StatusServiceUnavailable, t, err.Error()) + return resp + } + if h.enableHttpException && httpRes.StatusCode() >= 400 { + fillHttpException(resp, httpRes.StatusCode(), t, string(httpRes.Body())) return resp } + mhttp.FasthttpResponseToMotanResponse(resp, httpRes) + resp.ProcessTime = (time.Now().UnixNano() - t) / 1e6 + updateUpstreamStatusCode(resp, httpRes.StatusCode()) + return resp +} +// DoFormatURLQuery use ordinary client and parse the format url +func (h *HTTPProvider) DoFormatURLQuery(request motan.Request, t int64, ip string) motan.Response { + resp := mhttp.AcquireHttpMotanResponse() + resp.RequestID = request.GetRequestID() + toType := make([]interface{}, 1) + if err := request.ProcessDeserializable(toType); err != nil { + fillHttpException(resp, http.StatusBadRequest, t, err.Error()) + return resp + } httpReqURL, httpReqMethod, err := buildReqURL(request, h) if err != nil { - fillException(resp, t, err) + fillHttpException(resp, http.StatusServiceUnavailable, t, err.Error()) return resp } queryStr, err := buildQueryStr(request, h.url, h.mixVars) if err != nil { - fillException(resp, t, err) + fillHttpException(resp, http.StatusServiceUnavailable, t, err.Error()) return resp } var reqBody io.Reader @@ -390,7 +380,7 @@ func (h *HTTPProvider) Call(request motan.Request) motan.Response { req, err := http.NewRequest(httpReqMethod, httpReqURL, reqBody) if err != nil { vlog.Errorf("new HTTP Provider NewRequest err: %v", err) - fillException(resp, t, err) + fillHttpException(resp, http.StatusServiceUnavailable, t, err.Error()) return resp } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") //设置后,post参数才可正常传递 @@ -421,7 +411,7 @@ func (h *HTTPProvider) Call(request motan.Request) motan.Response { httpResp, err := c.Do(req) if err != nil { vlog.Errorf("new HTTP Provider Do HTTP Call err: %v", err) - fillException(resp, t, err) + fillHttpException(resp, http.StatusServiceUnavailable, t, err.Error()) return resp } defer httpResp.Body.Close() @@ -432,18 +422,16 @@ func (h *HTTPProvider) Call(request motan.Request) motan.Response { if l == 0 { vlog.Warningf("server_agent result is empty :%d,%d,%s", statusCode, request.GetRequestID(), httpReqURL) } - resp.ProcessTime = int64((time.Now().UnixNano() - t) / 1e6) + resp.ProcessTime = (time.Now().UnixNano() - t) / 1e6 if err != nil { vlog.Errorf("new HTTP Provider Read body err: %v", err) resp.Exception = &motan.Exception{ErrCode: statusCode, ErrMsg: fmt.Sprintf("%s", err), ErrType: http.StatusServiceUnavailable} return resp } - if h.enableHttpException { - if statusCode >= 400 { - fillHttpException(resp, statusCode, t, body) - return resp - } + if h.enableHttpException && statusCode >= 400 { + fillHttpException(resp, statusCode, t, string(body)) + return resp } request.GetAttachments().Range(func(k, v string) bool { resp.SetAttachment(k, v) @@ -457,6 +445,26 @@ func (h *HTTPProvider) Call(request motan.Request) motan.Response { return resp } +// Call for do a motan call through this provider +func (h *HTTPProvider) Call(request motan.Request) motan.Response { + t := time.Now().UnixNano() + doTransparentProxy, _ := strconv.ParseBool(request.GetAttachment(mhttp.Proxy)) + ip := "" + if remoteIP, exist := request.GetAttachments().Load(motan.RemoteIPKey); exist { + ip = remoteIP + } else { + ip = request.GetAttachment(motan.HostKey) + } + // Ok here we do transparent http proxy and return + if doTransparentProxy { + return h.DoTransparentProxy(request, t, ip) + } + if h.proxyAddr != "" { + return h.DoProxy(request, t, ip) + } + return h.DoFormatURLQuery(request, t, ip) +} + // GetName return this provider name func (h *HTTPProvider) GetName() string { return "HTTPProvider" @@ -498,21 +506,23 @@ func (h *HTTPProvider) GetPath() string { } func fillExceptionWithCode(resp *motan.MotanResponse, code int, start int64, err error) { - resp.ProcessTime = int64((time.Now().UnixNano() - start) / 1e6) + resp.ProcessTime = (time.Now().UnixNano() - start) / 1e6 resp.Exception = &motan.Exception{ErrCode: code, ErrMsg: fmt.Sprintf("%s", err), ErrType: code} } -func fillHttpException(resp *motan.MotanResponse, statusCode int, start int64, body []byte) { - resp.ProcessTime = int64((time.Now().UnixNano() - start) / 1e6) - resp.Exception = &motan.Exception{ErrCode: statusCode, ErrMsg: string(body), ErrType: motan.BizException} +func fillHttpExceptionWithCode(resp *mhttp.HttpMotanResponse, statusCode int, errType int, start int64, msg string) { + resp.ProcessTime = (time.Now().UnixNano() - start) / 1e6 + resp.Exception = &motan.Exception{ErrCode: statusCode, ErrMsg: msg, ErrType: errType} +} + +func fillHttpException(resp *mhttp.HttpMotanResponse, statusCode int, start int64, msg string) { + fillHttpExceptionWithCode(resp, statusCode, motan.BizException, start, msg) } func fillException(resp *motan.MotanResponse, start int64, err error) { fillExceptionWithCode(resp, http.StatusServiceUnavailable, start, err) } -func updateUpstreamStatusCode(resp *motan.MotanResponse, statusCode int) { - resCtx := resp.GetRPCContext(true) +func updateUpstreamStatusCode(resp *mhttp.HttpMotanResponse, statusCode int) { resp.SetAttachment(motan.MetaUpstreamCode, strconv.Itoa(statusCode)) - resCtx.Meta.Store(motan.MetaUpstreamCode, strconv.Itoa(statusCode)) } diff --git a/provider/motanProvider.go b/provider/motanProvider.go index c885d216..b2c5f860 100644 --- a/provider/motanProvider.go +++ b/provider/motanProvider.go @@ -65,7 +65,7 @@ func (m *MotanProvider) Call(request motan.Request) motan.Response { return m.ep.Call(request) } t := time.Now().UnixNano() - res := &motan.MotanResponse{Attachment: motan.NewStringMap(motan.DefaultAttachmentSize)} + res := motan.AcquireMotanResponse() fillException(res, t, errors.New("reverse proxy call err: motanProvider is unavailable")) return res } diff --git a/registry/consulRegistry.go b/registry/consulRegistry.go index 943583e7..ae1c1168 100644 --- a/registry/consulRegistry.go +++ b/registry/consulRegistry.go @@ -16,9 +16,11 @@ type ConsulRegistry struct { func (v *ConsulRegistry) GetURL() *motan.URL { return v.url } + func (v *ConsulRegistry) SetURL(url *motan.URL) { v.url = url } + func (v *ConsulRegistry) GetName() string { return "consulRegistry" } @@ -26,6 +28,7 @@ func (v *ConsulRegistry) GetName() string { func (v *ConsulRegistry) Initialize() { } + func (v *ConsulRegistry) Subscribe(url *motan.URL, listener motan.NotifyListener) { } @@ -33,22 +36,29 @@ func (v *ConsulRegistry) Subscribe(url *motan.URL, listener motan.NotifyListener func (v *ConsulRegistry) Unsubscribe(url *motan.URL, listener motan.NotifyListener) { } + func (v *ConsulRegistry) Discover(url *motan.URL) []*motan.URL { return nil } + func (v *ConsulRegistry) Register(serverURL *motan.URL) { } + func (v *ConsulRegistry) UnRegister(serverURL *motan.URL) { } + func (v *ConsulRegistry) Available(serverURL *motan.URL) { } + func (v *ConsulRegistry) Unavailable(serverURL *motan.URL) { } + func (v *ConsulRegistry) GetRegisteredServices() []*motan.URL { return nil } + func (v *ConsulRegistry) StartSnapshot(conf *motan.SnapshotConf) {} diff --git a/registry/directRegistry.go b/registry/directRegistry.go index fcdc6876..9bc46783 100644 --- a/registry/directRegistry.go +++ b/registry/directRegistry.go @@ -16,21 +16,25 @@ type DirectRegistry struct { func (d *DirectRegistry) GetURL() *motan.URL { return d.url } + func (d *DirectRegistry) SetURL(url *motan.URL) { d.url = url d.urls = parseURLs(url) } + func (d *DirectRegistry) GetName() string { return "direct" } func (d *DirectRegistry) InitRegistry() { } + func (d *DirectRegistry) Subscribe(url *motan.URL, listener motan.NotifyListener) { } func (d *DirectRegistry) Unsubscribe(url *motan.URL, listener motan.NotifyListener) { } + func (d *DirectRegistry) Discover(url *motan.URL) []*motan.URL { if d.urls == nil { d.urls = parseURLs(d.url) @@ -47,22 +51,29 @@ func (d *DirectRegistry) Discover(url *motan.URL) []*motan.URL { } return result } + func (d *DirectRegistry) Register(serverURL *motan.URL) { vlog.Infof("direct registry:register url :%+v", serverURL) } + func (d *DirectRegistry) UnRegister(serverURL *motan.URL) { } + func (d *DirectRegistry) Available(serverURL *motan.URL) { } + func (d *DirectRegistry) Unavailable(serverURL *motan.URL) { } + func (d *DirectRegistry) GetRegisteredServices() []*motan.URL { return nil } + func (d *DirectRegistry) StartSnapshot(conf *motan.SnapshotConf) {} + func parseURLs(url *motan.URL) []*motan.URL { urls := make([]*motan.URL, 0) if len(url.Host) > 0 && url.Port > 0 { diff --git a/registry/localRegistry.go b/registry/localRegistry.go index 7cff8316..cda744b9 100644 --- a/registry/localRegistry.go +++ b/registry/localRegistry.go @@ -15,6 +15,7 @@ func (d *LocalRegistry) GetURL() *motan.URL { func (d *LocalRegistry) SetURL(url *motan.URL) { d.url = url } + func (d *LocalRegistry) GetName() string { return "local" } @@ -40,12 +41,15 @@ func (d *LocalRegistry) Register(serverURL *motan.URL) { func (d *LocalRegistry) UnRegister(serverURL *motan.URL) { } + func (d *LocalRegistry) Available(serverURL *motan.URL) { } + func (d *LocalRegistry) Unavailable(serverURL *motan.URL) { } + func (d *LocalRegistry) GetRegisteredServices() []*motan.URL { return nil } diff --git a/registry/zkRegistry_test.go b/registry/zkRegistry_test.go index 534c8d28..4af4a75b 100644 --- a/registry/zkRegistry_test.go +++ b/registry/zkRegistry_test.go @@ -35,7 +35,7 @@ var ( z = &ZkRegistry{} ) -//Test path generation methods. +// Test path generation methods. func TestZkRegistryToPath(t *testing.T) { //Test path create methods. if p := toNodePath(testURL, zkNodeTypeServer); p != serverPath { diff --git a/server.go b/server.go index 808b8e82..63672ac2 100644 --- a/server.go +++ b/server.go @@ -113,6 +113,7 @@ func (m *MSContext) hashInt(s string) int { h.Write([]byte(s)) return int(h.Sum32()) } + func (m *MSContext) export(url *motan.URL) { defer motan.HandlePanic(nil) service := m.serviceImpls[url.Parameters[motan.RefKey]] diff --git a/server/motanserver.go b/server/motanserver.go index 29fbfc04..23e0059b 100644 --- a/server/motanserver.go +++ b/server/motanserver.go @@ -3,6 +3,8 @@ package server import ( "bufio" "errors" + "github.com/panjf2000/ants/v2" + mhttp "github.com/weibocom/motan-go/http" "net" "strconv" "strings" @@ -19,6 +21,15 @@ import ( var currentConnections int64 var motanServerOnce sync.Once +var processPool, _ = ants.NewPool(50000, ants.WithMaxBlockingTasks(1024)) + +func SetProcessPoolSize(size int) { + processPool.Tune(size) +} + +func GetProcessPoolSize() int { + return processPool.Cap() +} func incrConnections() { atomic.AddInt64(¤tConnections, 1) @@ -61,7 +72,7 @@ func (m *MotanServer) Open(block bool, proxy bool, handler motan.MessageHandler, } lis = listener } else { - addr := ":" + strconv.Itoa(int(m.URL.Port)) + addr := ":" + strconv.Itoa(m.URL.Port) if registry.IsAgent(m.URL) { addr = m.URL.Host + addr } @@ -151,7 +162,7 @@ func (m *MotanServer) handleConn(conn net.Conn) { } else { ip = getRemoteIP(conn.RemoteAddr().String()) } - + decodeBuf := make([]byte, mpro.DefaultBufferSize) for { v, err := mpro.CheckMotanVersion(buf) if err != nil { @@ -168,13 +179,15 @@ func (m *MotanServer) handleConn(conn net.Conn) { } go m.processV1(v1Msg, t, ip, conn) } else if v == mpro.Version2 { - msg, t, err := mpro.DecodeWithTime(buf, m.maxContextLength) + msg, t, err := mpro.DecodeWithTime(buf, &decodeBuf, m.maxContextLength) if err != nil { vlog.Warningf("decode motan v2 message fail! con:%s, err:%s.", conn.RemoteAddr().String(), err.Error()) break } - go m.processV2(msg, t, ip, conn) + processPool.Submit(func() { + m.processV2(msg, t, ip, conn) + }) } else { vlog.Warningf("unsupported motan version! version:%d con:%s", v, conn.RemoteAddr().String()) break @@ -219,7 +232,7 @@ func (m *MotanServer) processV2(msg *mpro.Message, start time.Time, ip string, c tc.PutReqSpan(&motan.Span{Name: motan.Convert, Time: time.Now()}) req.GetRPCContext(true).Tc = tc } - callStart := time.Now() + mres = m.handler.Call(req) if tc != nil { // clusterFilter end @@ -229,7 +242,7 @@ func (m *MotanServer) processV2(msg *mpro.Message, start time.Time, ip string, c resCtx := mres.GetRPCContext(true) resCtx.Proxy = m.proxy if mres.GetAttachment(mpro.MProcessTime) == "" { - mres.SetAttachment(mpro.MProcessTime, strconv.FormatInt(int64(time.Now().Sub(callStart)/1e6), 10)) + mres.SetAttachment(mpro.MProcessTime, strconv.FormatInt(int64(time.Now().Sub(start)/1e6), 10)) } res, err = mpro.ConvertToResMessage(mres, serialization) if tc != nil { @@ -246,12 +259,14 @@ func (m *MotanServer) processV2(msg *mpro.Message, start time.Time, ip string, c } // recover the communication identifier res.Header.RequestID = lastRequestID - resBuf := res.Encode() + res.Encode0() if tc != nil { tc.PutResSpan(&motan.Span{Name: motan.Encode, Time: time.Now()}) } + var sendBuf net.Buffers = res.GetEncodedBytes() conn.SetWriteDeadline(time.Now().Add(motan.DefaultWriteTimeout)) - _, err := conn.Write(resBuf.Bytes()) + _, err := sendBuf.WriteTo(conn) + res.SetCanRelease() if err != nil { vlog.Errorf("connection will close. conn: %s, err:%s", conn.RemoteAddr().String(), err.Error()) conn.Close() @@ -268,6 +283,18 @@ func (m *MotanServer) processV2(msg *mpro.Message, start time.Time, ip string, c if tc != nil { tc.PutResSpan(&motan.Span{Name: motan.Send, Time: resSendTime}) } + // 回收message + mpro.ReleaseMessage(msg) + mpro.ReleaseMessage(res) + // 回收request + if motanReq, ok := mreq.(*motan.MotanRequest); ok { + motan.ReleaseMotanRequest(motanReq) + } + if motanResp, ok := mres.(*motan.MotanResponse); ok { + motan.ReleaseMotanResponse(motanResp) + } else if motanHttpRes, ok := mres.(*mhttp.HttpMotanResponse); ok { + mhttp.ReleaseHttpMotanResponse(motanHttpRes) + } } func (m *MotanServer) processV1(msg *mpro.MotanV1Message, start time.Time, ip string, conn net.Conn) { @@ -276,6 +303,13 @@ func (m *MotanServer) processV1(msg *mpro.MotanV1Message, start time.Time, ip st var result []byte var reqCtx *motan.RPCContext req, err := mpro.DecodeMotanV1Request(msg) + // fill v2 attachment + if req.GetAttachment(mpro.MGroup) == "" { + req.SetAttachment(mpro.MGroup, req.GetAttachment(mpro.V1Group)) + } + if req.GetAttachment(mpro.MVersion) == "" { + req.SetAttachment(mpro.MVersion, req.GetAttachment(mpro.V1Version)) + } if err != nil { vlog.Errorf("decode v1 request fail. conn: %s, err:%s", conn.RemoteAddr().String(), err.Error()) result = mpro.BuildV1ExceptionResponse(msg.Rid, err.Error()) @@ -337,7 +371,7 @@ func getRemoteIP(address string) string { var ip string index := strings.Index(address, ":") if index > 0 { - ip = string(address[:index]) + ip = address[:index] } else { ip = address } diff --git a/tools/nginx/parser.go b/tools/nginx/parser.go index f8c28f20..118beca7 100644 --- a/tools/nginx/parser.go +++ b/tools/nginx/parser.go @@ -29,10 +29,11 @@ const ( Error ) -// location / { # directive location -// if ($request_uri ~= '/*') { # directive if -// } -// } +// Directive location / { # directive location +// +// if ($request_uri ~= '/*') { # directive if +// } +// } type Directive struct { name string args []string @@ -69,6 +70,7 @@ func init() { func NewParser(reader io.Reader) *Parser { return &Parser{reader: bufio.NewReader(reader)} } + func (p *Parser) readToken() token { buf := bytes.Buffer{} sharpComment := false