diff --git a/application/end.go b/application/end.go index ed922be..e33df6a 100644 --- a/application/end.go +++ b/application/end.go @@ -29,8 +29,9 @@ type opts struct { // delegate dlgt delegate.ApplicationDelegate // methods - remoteMethods []string - localMethods []*geminio.MethodRPC + remoteMethods []string + remoteMethodCheck bool + localMethods []*geminio.MethodRPC // callback funcs acceptStreamFunc func(geminio.Stream) closedStreamFunc func(geminio.Stream) @@ -73,6 +74,12 @@ func OptionWaitRemoteRPCs(methods ...string) EndOption { } } +func OptionWithRemoteRPCCheck() EndOption { + return func(end *End) { + end.remoteMethodCheck = true + } +} + func OptionRegisterLocalRPCs(methodRPCs ...*geminio.MethodRPC) EndOption { return func(end *End) { end.localMethods = methodRPCs diff --git a/application/rpc.go b/application/rpc.go index cd9c904..f3817b1 100644 --- a/application/rpc.go +++ b/application/rpc.go @@ -84,14 +84,16 @@ func (sm *stream) Call(ctx context.Context, method string, req geminio.Request, return nil, io.EOF } - // check remote RPC exists - sm.rpcMtx.RLock() - _, ok := sm.remoteRPCs[method] - if !ok { + if sm.opts.remoteMethodCheck { + // check remote RPC exists + sm.rpcMtx.RLock() + _, ok := sm.remoteRPCs[method] + if !ok { + sm.rpcMtx.RUnlock() + return nil, ErrRemoteRPCUnregistered + } sm.rpcMtx.RUnlock() - return nil, ErrRemoteRPCUnregistered } - sm.rpcMtx.RUnlock() // transfer to underlayer packet pkt := sm.pf.NewRequestPacketWithIDAndSessionID(req.ID(), sm.dg.DialogueID(), []byte(method), req.Data()) if req.Timeout() != 0 { diff --git a/client/end.go b/client/end.go index e9a4fb7..fe18b60 100644 --- a/client/end.go +++ b/client/end.go @@ -101,6 +101,9 @@ func new(netcn net.Conn, opts ...*EndOptions) (geminio.End, error) { application.OptionWaitRemoteRPCs(eo.RemoteMethods...), application.OptionRegisterLocalRPCs(eo.LocalMethods...), } + if eo.RemoteMethodCheck { + epOpts = append(epOpts, application.OptionWithRemoteRPCCheck()) + } ep, err = application.NewEnd(cn, mp, epOpts...) if err != nil { goto ERR diff --git a/client/end_options.go b/client/end_options.go index 311993a..9934cb4 100644 --- a/client/end_options.go +++ b/client/end_options.go @@ -11,16 +11,17 @@ import ( ) type EndOptions struct { - Timer timer.Timer - TimerOwner interface{} - PacketFactory packet.PacketFactory - Log log.Logger - Delegate delegate.ClientDelegate - delegate delegate.ClientDelegate - ClientID *uint64 - Meta []byte - RemoteMethods []string - LocalMethods []*geminio.MethodRPC + Timer timer.Timer + TimerOwner interface{} + PacketFactory packet.PacketFactory + Log log.Logger + Delegate delegate.ClientDelegate + delegate delegate.ClientDelegate + ClientID *uint64 + Meta []byte + RemoteMethods []string + RemoteMethodCheck bool + LocalMethods []*geminio.MethodRPC } func (eo *EndOptions) SetTimer(timer timer.Timer) { @@ -57,6 +58,10 @@ func (eo *EndOptions) SetWaitRemoteRPCs(methods ...string) { eo.RemoteMethods = methods } +func (eo *EndOptions) SetRemoteRPCCheck() { + eo.RemoteMethodCheck = true +} + func (eo *EndOptions) SetRegisterLocalRPCs(methodRPCs ...*geminio.MethodRPC) { eo.LocalMethods = methodRPCs } @@ -71,6 +76,7 @@ func MergeEndOptions(opts ...*EndOptions) *EndOptions { if opt == nil { continue } + eo.RemoteMethodCheck = opt.RemoteMethodCheck if opt.Timer != nil { eo.Timer = opt.Timer eo.TimerOwner = opt.TimerOwner diff --git a/client/end_retry_options.go b/client/end_retry_options.go index c8ed885..879ac99 100644 --- a/client/end_retry_options.go +++ b/client/end_retry_options.go @@ -24,6 +24,7 @@ func MergeRetryEndOptions(opts ...*RetryEndOptions) *RetryEndOptions { if opt == nil { continue } + eo.RemoteMethodCheck = opt.RemoteMethodCheck if opt.Timer != nil { eo.Timer = opt.Timer eo.TimerOwner = opt.TimerOwner diff --git a/examples/usage/brpc/server/main.go b/examples/usage/brpc/server/main.go index 683d072..8f367c2 100644 --- a/examples/usage/brpc/server/main.go +++ b/examples/usage/brpc/server/main.go @@ -13,7 +13,7 @@ func main() { // the option means all End from server will wait for the rpc registration opt.SetWaitRemoteRPCs("client-echo") // pre-register server side method - opt.SetRegisterLocalRPCs(&geminio.MethodRPC{"server-echo", echo}) + opt.SetRegisterLocalRPCs(&geminio.MethodRPC{Method: "server-echo", RPC: echo}) ln, err := server.Listen("tcp", "127.0.0.1:8080", opt) if err != nil { diff --git a/server/end.go b/server/end.go index 890778b..6fcdccb 100644 --- a/server/end.go +++ b/server/end.go @@ -109,6 +109,9 @@ func new(netcn net.Conn, opts ...*EndOptions) (geminio.End, error) { if eo.ClosedStreamFunc != nil { epOpts = append(epOpts, application.OptionClosedStreamFunc(eo.ClosedStreamFunc)) } + if eo.RemoteMethodCheck { + epOpts = append(epOpts, application.OptionWithRemoteRPCCheck()) + } ep, err = application.NewEnd(cn, mp, epOpts...) if err != nil { goto ERR diff --git a/server/end_options.go b/server/end_options.go index 8a9d6e9..dbc880c 100644 --- a/server/end_options.go +++ b/server/end_options.go @@ -10,14 +10,15 @@ import ( ) type EndOptions struct { - Timer timer.Timer - TimerOwner interface{} - PacketFactory packet.PacketFactory - Log log.Logger - Delegate delegate.ServerDelegate - ClientID *uint64 - RemoteMethods []string - LocalMethods []*geminio.MethodRPC + Timer timer.Timer + TimerOwner interface{} + PacketFactory packet.PacketFactory + Log log.Logger + Delegate delegate.ServerDelegate + ClientID *uint64 + RemoteMethods []string + RemoteMethodCheck bool + LocalMethods []*geminio.MethodRPC // If set AcceptStreamFunc, the AcceptStream should never be called AcceptStreamFunc func(geminio.Stream) ClosedStreamFunc func(geminio.Stream) @@ -48,6 +49,10 @@ func (eo *EndOptions) SetWaitRemoteRPCs(methods ...string) { eo.RemoteMethods = methods } +func (eo *EndOptions) SetRemoteRPCCheck() { + eo.RemoteMethodCheck = true +} + func (eo *EndOptions) SetRegisterLocalRPCs(methodRPCs ...*geminio.MethodRPC) { eo.LocalMethods = methodRPCs } @@ -70,6 +75,7 @@ func MergeEndOptions(opts ...*EndOptions) *EndOptions { if opt == nil { continue } + eo.RemoteMethodCheck = opt.RemoteMethodCheck if opt.Timer != nil { eo.Timer = opt.Timer eo.TimerOwner = opt.TimerOwner diff --git a/server/server.go b/server/server.go index 2f56aff..ef13e49 100644 --- a/server/server.go +++ b/server/server.go @@ -22,9 +22,15 @@ type Listener interface { Addr() net.Addr } +type ret struct { + end geminio.End + err error +} + type listener struct { opts []*EndOptions ln net.Listener + ch chan *ret } func Listen(network, address string, opts ...*EndOptions) (Listener, error) { @@ -32,7 +38,10 @@ func Listen(network, address string, opts ...*EndOptions) (Listener, error) { if err != nil { return nil, err } - return &listener{ln: ln, opts: opts}, nil + return &listener{ + ln: ln, + opts: opts, + ch: make(chan *ret, 128)}, nil } func (ln *listener) AcceptEnd() (geminio.End, error) { @@ -40,8 +49,12 @@ func (ln *listener) AcceptEnd() (geminio.End, error) { if err != nil { return nil, err } - end, err := NewEndWithConn(netconn, ln.opts...) - return end, err + go func() { + end, err := NewEndWithConn(netconn, ln.opts...) + ln.ch <- &ret{end, err} + }() + ret, _ := <-ln.ch + return ret.end, ret.err } func (ln *listener) Accept() (net.Conn, error) { diff --git a/test/regression/regression_test.go b/test/regression/regression_test.go index 111c073..cfa3849 100644 --- a/test/regression/regression_test.go +++ b/test/regression/regression_test.go @@ -2,10 +2,15 @@ package regression import ( "context" + "encoding/binary" "errors" + "sync" + "sync/atomic" "testing" "github.com/singchia/geminio" + "github.com/singchia/geminio/client" + "github.com/singchia/geminio/server" "github.com/singchia/geminio/test" ) @@ -66,3 +71,53 @@ func TestMessage(t *testing.T) { <-done } + +func TestServer(t *testing.T) { + network := "tcp" + address := "127.0.0.1:12345" + srv, err := server.Listen(network, address) + if err != nil { + t.Error(err) + return + } + wg := sync.WaitGroup{} + wg.Add(1) + + index, count := uint64(0), uint64(1000) + accepted := make([]uint64, count) + go func() { + for { + end, err := srv.AcceptEnd() + if err != nil { + t.Error(err) + return + } + meta := end.Meta() + id := binary.BigEndian.Uint64(meta) + accepted[id-1] = 1 + new := atomic.AddUint64(&index, 1) + if new == count { + wg.Done() + } + } + }() + + for i := uint64(0); i < count; i++ { + opt := client.NewEndOptions() + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, i+1) + opt.SetMeta(buf) + _, err := client.NewEnd(network, address, opt) + if err != nil { + t.Error(err) + } + } + + wg.Wait() + for _, elem := range accepted { + if elem != 1 { + t.Error("failed end exist") + return + } + } +}