diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_mux.go b/pkg/streamx/provider/ttstream/client_trans_pool_mux.go index 4775cb40bc..1d524fc5cf 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_mux.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_mux.go @@ -19,6 +19,7 @@ package ttstream import ( "runtime" "sync" + "sync/atomic" "time" "github.com/cloudwego/kitex/pkg/serviceinfo" @@ -28,43 +29,76 @@ import ( var _ transPool = (*muxTransPool)(nil) +type muxTransList struct { + L sync.RWMutex + size int + cursor uint32 + transports []*transport +} + +func newMuxTransList(size int) *muxTransList { + tl := new(muxTransList) + tl.size = size + tl.transports = make([]*transport, size) + return tl +} + +func (tl *muxTransList) Get(sinfo *serviceinfo.ServiceInfo, network string, addr string) (*transport, error) { + idx := atomic.AddUint32(&tl.cursor, 1) % uint32(tl.size) + tl.L.RLock() + trans := tl.transports[idx] + tl.L.RUnlock() + if trans != nil && trans.IsActive() { + return trans, nil + } + + conn, err := netpoll.DialConnection(network, addr, time.Second) + if err != nil { + return nil, err + } + trans = newTransport(clientTransport, sinfo, conn) + _ = conn.AddCloseCallback(func(connection netpoll.Connection) error { + // peer close + _ = trans.Close() + return nil + }) + runtime.SetFinalizer(trans, func(trans *transport) { + // self close when not hold by user + _ = trans.Close() + }) + tl.L.Lock() + tl.transports[idx] = trans + tl.L.Unlock() + return trans, nil +} + func newMuxTransPool() transPool { t := new(muxTransPool) + t.poolSize = runtime.GOMAXPROCS(0) return t } type muxTransPool struct { - pool sync.Map // addr:*transport - sflight singleflight.Group + poolSize int + pool sync.Map // addr:*muxTransList + sflight singleflight.Group } func (m *muxTransPool) Get(sinfo *serviceinfo.ServiceInfo, network string, addr string) (trans *transport, err error) { v, ok := m.pool.Load(addr) if ok { - return v.(*transport), nil + return v.(*muxTransList).Get(sinfo, network, addr) } + v, err, _ = m.sflight.Do(addr, func() (interface{}, error) { - conn, err := netpoll.DialConnection(network, addr, time.Second) - if err != nil { - return nil, err - } - trans = newTransport(clientTransport, sinfo, conn) - _ = conn.AddCloseCallback(func(connection netpoll.Connection) error { - // peer close - _ = trans.Close() - return nil - }) - m.pool.Store(addr, trans) - runtime.SetFinalizer(trans, func(trans *transport) { - // self close when not hold by user - _ = trans.Close() - }) - return trans, nil + transList := newMuxTransList(m.poolSize) + m.pool.Store(addr, transList) + return transList, nil }) if err != nil { return nil, err } - return v.(*transport), nil + return v.(*muxTransList).Get(sinfo, network, addr) } func (m *muxTransPool) Put(trans *transport) { diff --git a/pkg/streamx/provider/ttstream/transport.go b/pkg/streamx/provider/ttstream/transport.go index 30ce15ad19..59448b39ec 100644 --- a/pkg/streamx/provider/ttstream/transport.go +++ b/pkg/streamx/provider/ttstream/transport.go @@ -18,7 +18,6 @@ package ttstream import ( "context" - "encoding/binary" "errors" "fmt" "io" @@ -27,6 +26,7 @@ import ( "time" "github.com/bytedance/gopkg/lang/mcache" + "github.com/cloudwego/gopkg/bufiox" "github.com/cloudwego/gopkg/protocol/thrift" "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/serviceinfo" @@ -40,8 +40,8 @@ const ( serverTransport int32 = 2 streamCacheSize = 32 - frameCacheSize = 1024 - batchWriteSize = 32 + frameCacheSize = 32 + batchWriteSize = 8 ) func isIgnoreError(err error) bool { @@ -136,83 +136,74 @@ func (t *transport) loadStreamIO(sid int32) (sio *streamIO, ok bool) { return sio, true } -func (t *transport) loopRead() error { +func (t *transport) readFrame(reader bufiox.Reader) error { addr := t.conn.RemoteAddr().String() if t.kind == clientTransport { addr = t.conn.LocalAddr().String() } - for { - // decode frame - sizeBuf, err := t.conn.Reader().Peek(4) - if err != nil { - return err - } - size := binary.BigEndian.Uint32(sizeBuf) - slice, err := t.conn.Reader().Slice(int(size + 4)) - if err != nil { - return err - } - reader := newReaderBuffer(slice) - fr, err := DecodeFrame(context.Background(), reader) - if err != nil { - return err - } - klog.Debugf("transport[%d-%s] DecodeFrame: fr=%v", t.kind, addr, fr) + fr, err := DecodeFrame(context.Background(), reader) + if err != nil { + return err + } + klog.Debugf("transport[%d-%s] DecodeFrame: fr=%v", t.kind, addr, fr) - switch fr.typ { - case metaFrameType: - sio, ok := t.loadStreamIO(fr.sid) - if !ok { - klog.Errorf("transport[%d-%s] read a unknown stream meta: sid=%d", t.kind, addr, fr.sid) - continue - } + switch fr.typ { + case metaFrameType: + sio, ok := t.loadStreamIO(fr.sid) + if ok { err = sio.stream.readMetaFrame(fr.meta, fr.header, fr.payload) - if err != nil { - return err - } - case headerFrameType: - switch t.kind { - case serverTransport: - // Header Frame: server recv a new stream - smode := t.sinfo.MethodInfo(fr.method).StreamingMode() - s := newStream(t, smode, fr.streamFrame) - t.storeStreamIO(context.Background(), s) - t.spipe.Write(context.Background(), s) - case clientTransport: - // Header Frame: client recv header - sio, ok := t.loadStreamIO(fr.sid) - if !ok { - klog.Errorf("transport[%d-%s] read a unknown stream header: sid=%d header=%v", - t.kind, addr, fr.sid, fr.header) - continue - } - err = sio.stream.readHeader(fr.header) - if err != nil { - return err - } - } - case dataFrameType: - // Data Frame: decode and distribute data + } else { + klog.Errorf("transport[%d-%s] read a unknown stream meta: sid=%d", t.kind, addr, fr.sid) + } + case headerFrameType: + switch t.kind { + case serverTransport: + // Header Frame: server recv a new stream + smode := t.sinfo.MethodInfo(fr.method).StreamingMode() + s := newStream(t, smode, fr.streamFrame) + t.storeStreamIO(context.Background(), s) + err = t.spipe.Write(context.Background(), s) + case clientTransport: + // Header Frame: client recv header sio, ok := t.loadStreamIO(fr.sid) - if !ok { - klog.Errorf("transport[%d-%s] read a unknown stream data: sid=%d", t.kind, addr, fr.sid) - continue + if ok { + err = sio.stream.readHeader(fr.header) + } else { + klog.Errorf("transport[%d-%s] read a unknown stream header: sid=%d header=%v", + t.kind, addr, fr.sid, fr.header) } + } + case dataFrameType: + // Data Frame: decode and distribute data + sio, ok := t.loadStreamIO(fr.sid) + if ok { sio.input(context.Background(), fr) - case trailerFrameType: - // Trailer Frame: recv trailer, Close read direction - sio, ok := t.loadStreamIO(fr.sid) - if !ok { - // client recv an unknown trailer is in exception, - // because the client stream may already be GCed, - // but the connection is still active so peer server can send a trailer - klog.Errorf("transport[%d-%s] read a unknown stream trailer: sid=%d trailer=%v", - t.kind, addr, fr.sid, fr.trailer) - continue - } - if err = sio.stream.readTrailer(fr.trailer); err != nil { - return err - } + } else { + klog.Errorf("transport[%d-%s] read a unknown stream data: sid=%d", t.kind, addr, fr.sid) + } + case trailerFrameType: + // Trailer Frame: recv trailer, Close read direction + sio, ok := t.loadStreamIO(fr.sid) + if ok { + err = sio.stream.readTrailer(fr.trailer) + } else { + // client recv an unknown trailer is in exception, + // because the client stream may already be GCed, + // but the connection is still active so peer server can send a trailer + klog.Errorf("transport[%d-%s] read a unknown stream trailer: sid=%d trailer=%v", + t.kind, addr, fr.sid, fr.trailer) + } + } + return err +} + +func (t *transport) loopRead() error { + reader := newReaderBuffer(t.conn.Reader()) + for { + err := t.readFrame(reader) + // read frame return an un-recovered error, so we should close the transport + if err != nil { + return err } } } diff --git a/pkg/streamx/provider/ttstream/ttstream_client_test.go b/pkg/streamx/provider/ttstream/ttstream_client_test.go index 4aa3bbe9bb..d1c1e0eaa6 100644 --- a/pkg/streamx/provider/ttstream/ttstream_client_test.go +++ b/pkg/streamx/provider/ttstream/ttstream_client_test.go @@ -270,7 +270,7 @@ func TestTTHeaderStreaming(t *testing.T) { t.Logf("Client ClientStream CloseAndRecv: %v", res) atomic.AddInt32(&serverStreamCount, -1) waitServerStreamDone() - test.Assert(t, serverRecvCount == int32(round), serverRecvCount) + test.DeepEqual(t, serverRecvCount, int32(round)) test.Assert(t, serverSendCount == 1, serverSendCount) testHeaderAndTrailer(t, cs) cs = nil