perf: mux trans pool
joway committed Oct 10, 2024
1 parent 9add164 commit 905b088
Showing 3 changed files with 117 additions and 92 deletions.
74 changes: 54 additions & 20 deletions pkg/streamx/provider/ttstream/client_trans_pool_mux.go
@@ -19,6 +19,7 @@ package ttstream
 import (
 	"runtime"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/cloudwego/kitex/pkg/serviceinfo"
@@ -28,43 +29,76 @@ import (
 
 var _ transPool = (*muxTransPool)(nil)
 
+type muxTransList struct {
+	L          sync.RWMutex
+	size       int
+	cursor     uint32
+	transports []*transport
+}
+
+func newMuxTransList(size int) *muxTransList {
+	tl := new(muxTransList)
+	tl.size = size
+	tl.transports = make([]*transport, size)
+	return tl
+}
+
+func (tl *muxTransList) Get(sinfo *serviceinfo.ServiceInfo, network string, addr string) (*transport, error) {
+	idx := atomic.AddUint32(&tl.cursor, 1) % uint32(tl.size)
+	tl.L.RLock()
+	trans := tl.transports[idx]
+	tl.L.RUnlock()
+	if trans != nil && trans.IsActive() {
+		return trans, nil
+	}
+
+	conn, err := netpoll.DialConnection(network, addr, time.Second)
+	if err != nil {
+		return nil, err
+	}
+	trans = newTransport(clientTransport, sinfo, conn)
+	_ = conn.AddCloseCallback(func(connection netpoll.Connection) error {
+		// peer close
+		_ = trans.Close()
+		return nil
+	})
+	runtime.SetFinalizer(trans, func(trans *transport) {
+		// self close when not hold by user
+		_ = trans.Close()
+	})
+	tl.L.Lock()
+	tl.transports[idx] = trans
+	tl.L.Unlock()
+	return trans, nil
+}
+
 func newMuxTransPool() transPool {
 	t := new(muxTransPool)
+	t.poolSize = runtime.GOMAXPROCS(0)
 	return t
 }
 
 type muxTransPool struct {
-	pool    sync.Map // addr:*transport
-	sflight singleflight.Group
+	poolSize int
+	pool     sync.Map // addr:*muxTransList
+	sflight  singleflight.Group
 }
 
 func (m *muxTransPool) Get(sinfo *serviceinfo.ServiceInfo, network string, addr string) (trans *transport, err error) {
 	v, ok := m.pool.Load(addr)
 	if ok {
-		return v.(*transport), nil
+		return v.(*muxTransList).Get(sinfo, network, addr)
 	}
 
 	v, err, _ = m.sflight.Do(addr, func() (interface{}, error) {
-		conn, err := netpoll.DialConnection(network, addr, time.Second)
-		if err != nil {
-			return nil, err
-		}
-		trans = newTransport(clientTransport, sinfo, conn)
-		_ = conn.AddCloseCallback(func(connection netpoll.Connection) error {
-			// peer close
-			_ = trans.Close()
-			return nil
-		})
-		m.pool.Store(addr, trans)
-		runtime.SetFinalizer(trans, func(trans *transport) {
-			// self close when not hold by user
-			_ = trans.Close()
-		})
-		return trans, nil
+		transList := newMuxTransList(m.poolSize)
+		m.pool.Store(addr, transList)
+		return transList, nil
 	})
 	if err != nil {
 		return nil, err
 	}
-	return v.(*transport), nil
+	return v.(*muxTransList).Get(sinfo, network, addr)
 }
 
 func (m *muxTransPool) Put(trans *transport) {

Check failure (GitHub Actions / windows-test) on line 55 in pkg/streamx/provider/ttstream/client_trans_pool_mux.go: undefined: netpoll.DialConnection
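
The heart of this change is the per-address round-robin list: instead of one multiplexed connection per address, the pool now keeps up to runtime.GOMAXPROCS(0) transports, spreads callers across them with an atomic cursor, dials a slot lazily, and re-dials it once it goes inactive. The standalone sketch below illustrates only that selection pattern; the conn type and dial function are placeholders for this example, not Kitex or netpoll APIs.

package main

import (
	"fmt"
	"runtime"
	"sync"
	"sync/atomic"
)

// conn stands in for a multiplexed transport in this sketch.
type conn struct{ id int }

func (c *conn) IsActive() bool { return true }

// rrList mirrors the muxTransList idea: a fixed-size slot array plus an
// atomic cursor, so concurrent callers are spread over the slots round-robin.
type rrList struct {
	mu     sync.RWMutex
	cursor uint32
	slots  []*conn
}

func newRRList(size int) *rrList { return &rrList{slots: make([]*conn, size)} }

func (l *rrList) Get(dial func() *conn) *conn {
	idx := atomic.AddUint32(&l.cursor, 1) % uint32(len(l.slots))
	l.mu.RLock()
	c := l.slots[idx]
	l.mu.RUnlock()
	if c != nil && c.IsActive() {
		return c
	}
	// Slot is empty or dead: dial lazily and publish it for later callers.
	// Two racing callers may both dial; the slower one simply overwrites the slot.
	c = dial()
	l.mu.Lock()
	l.slots[idx] = c
	l.mu.Unlock()
	return c
}

func main() {
	list := newRRList(runtime.GOMAXPROCS(0))
	next := 0
	dial := func() *conn { next++; return &conn{id: next} }
	for i := 0; i < 8; i++ {
		fmt.Println("picked conn", list.Get(dial).id)
	}
}

Compared with the old pool, which stored a single *transport per address, streams to one peer are now spread over several connections instead of all sharing one.
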
133 changes: 62 additions & 71 deletions pkg/streamx/provider/ttstream/transport.go
@@ -18,7 +18,6 @@ package ttstream
 
 import (
 	"context"
-	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
@@ -27,6 +26,7 @@ import (
 	"time"
 
 	"github.com/bytedance/gopkg/lang/mcache"
+	"github.com/cloudwego/gopkg/bufiox"
 	"github.com/cloudwego/gopkg/protocol/thrift"
 	"github.com/cloudwego/kitex/pkg/klog"
 	"github.com/cloudwego/kitex/pkg/serviceinfo"
@@ -40,8 +40,8 @@ const (
 	serverTransport int32 = 2
 
 	streamCacheSize = 32
-	frameCacheSize  = 1024
-	batchWriteSize  = 32
+	frameCacheSize  = 32
+	batchWriteSize  = 8
 )
 
 func isIgnoreError(err error) bool {
@@ -136,83 +136,74 @@ func (t *transport) loadStreamIO(sid int32) (sio *streamIO, ok bool) {
 	return sio, true
 }
 
-func (t *transport) loopRead() error {
+func (t *transport) readFrame(reader bufiox.Reader) error {
 	addr := t.conn.RemoteAddr().String()
 	if t.kind == clientTransport {
 		addr = t.conn.LocalAddr().String()
 	}
-	for {
-		// decode frame
-		sizeBuf, err := t.conn.Reader().Peek(4)
-		if err != nil {
-			return err
-		}
-		size := binary.BigEndian.Uint32(sizeBuf)
-		slice, err := t.conn.Reader().Slice(int(size + 4))
-		if err != nil {
-			return err
-		}
-		reader := newReaderBuffer(slice)
-		fr, err := DecodeFrame(context.Background(), reader)
-		if err != nil {
-			return err
-		}
-		klog.Debugf("transport[%d-%s] DecodeFrame: fr=%v", t.kind, addr, fr)
-
-		switch fr.typ {
-		case metaFrameType:
-			sio, ok := t.loadStreamIO(fr.sid)
-			if !ok {
-				klog.Errorf("transport[%d-%s] read a unknown stream meta: sid=%d", t.kind, addr, fr.sid)
-				continue
-			}
-			err = sio.stream.readMetaFrame(fr.meta, fr.header, fr.payload)
-			if err != nil {
-				return err
-			}
-		case headerFrameType:
-			switch t.kind {
-			case serverTransport:
-				// Header Frame: server recv a new stream
-				smode := t.sinfo.MethodInfo(fr.method).StreamingMode()
-				s := newStream(t, smode, fr.streamFrame)
-				t.storeStreamIO(context.Background(), s)
-				t.spipe.Write(context.Background(), s)
-			case clientTransport:
-				// Header Frame: client recv header
-				sio, ok := t.loadStreamIO(fr.sid)
-				if !ok {
-					klog.Errorf("transport[%d-%s] read a unknown stream header: sid=%d header=%v",
-						t.kind, addr, fr.sid, fr.header)
-					continue
-				}
-				err = sio.stream.readHeader(fr.header)
-				if err != nil {
-					return err
-				}
-			}
-		case dataFrameType:
-			// Data Frame: decode and distribute data
-			sio, ok := t.loadStreamIO(fr.sid)
-			if !ok {
-				klog.Errorf("transport[%d-%s] read a unknown stream data: sid=%d", t.kind, addr, fr.sid)
-				continue
-			}
-			sio.input(context.Background(), fr)
-		case trailerFrameType:
-			// Trailer Frame: recv trailer, Close read direction
-			sio, ok := t.loadStreamIO(fr.sid)
-			if !ok {
-				// client recv an unknown trailer is in exception,
-				// because the client stream may already be GCed,
-				// but the connection is still active so peer server can send a trailer
-				klog.Errorf("transport[%d-%s] read a unknown stream trailer: sid=%d trailer=%v",
-					t.kind, addr, fr.sid, fr.trailer)
-				continue
-			}
-			if err = sio.stream.readTrailer(fr.trailer); err != nil {
-				return err
-			}
-		}
-	}
-}
+	fr, err := DecodeFrame(context.Background(), reader)
+	if err != nil {
+		return err
+	}
+	klog.Debugf("transport[%d-%s] DecodeFrame: fr=%v", t.kind, addr, fr)
+
+	switch fr.typ {
+	case metaFrameType:
+		sio, ok := t.loadStreamIO(fr.sid)
+		if ok {
+			err = sio.stream.readMetaFrame(fr.meta, fr.header, fr.payload)
+		} else {
+			klog.Errorf("transport[%d-%s] read a unknown stream meta: sid=%d", t.kind, addr, fr.sid)
+		}
+	case headerFrameType:
+		switch t.kind {
+		case serverTransport:
+			// Header Frame: server recv a new stream
+			smode := t.sinfo.MethodInfo(fr.method).StreamingMode()
+			s := newStream(t, smode, fr.streamFrame)
+			t.storeStreamIO(context.Background(), s)
+			err = t.spipe.Write(context.Background(), s)
+		case clientTransport:
+			// Header Frame: client recv header
+			sio, ok := t.loadStreamIO(fr.sid)
+			if ok {
+				err = sio.stream.readHeader(fr.header)
+			} else {
+				klog.Errorf("transport[%d-%s] read a unknown stream header: sid=%d header=%v",
+					t.kind, addr, fr.sid, fr.header)
+			}
+		}
+	case dataFrameType:
+		// Data Frame: decode and distribute data
+		sio, ok := t.loadStreamIO(fr.sid)
+		if ok {
+			sio.input(context.Background(), fr)
+		} else {
+			klog.Errorf("transport[%d-%s] read a unknown stream data: sid=%d", t.kind, addr, fr.sid)
+		}
+	case trailerFrameType:
+		// Trailer Frame: recv trailer, Close read direction
+		sio, ok := t.loadStreamIO(fr.sid)
+		if ok {
+			err = sio.stream.readTrailer(fr.trailer)
+		} else {
+			// client recv an unknown trailer is in exception,
+			// because the client stream may already be GCed,
+			// but the connection is still active so peer server can send a trailer
+			klog.Errorf("transport[%d-%s] read a unknown stream trailer: sid=%d trailer=%v",
+				t.kind, addr, fr.sid, fr.trailer)
+		}
+	}
+	return err
+}
+
+func (t *transport) loopRead() error {
+	reader := newReaderBuffer(t.conn.Reader())
+	for {
+		err := t.readFrame(reader)
+		// read frame return an un-recovered error, so we should close the transport
+		if err != nil {
+			return err
+		}
+	}
+}
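
The transport change replaces the per-frame Peek(4)/Slice dance on the netpoll reader with a single buffered reader that loopRead creates once and hands to readFrame, which decodes and dispatches exactly one frame and returns any error to the loop. The sketch below shows that loop-plus-handler shape with a toy length-prefixed frame read from an in-memory buffer; the frame layout, names, and bufio usage are illustrative only, not the ttstream wire format or the bufiox API.

package main

import (
	"bufio"
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
)

// readFrame consumes one [4-byte big-endian length][payload] frame from the
// shared reader and handles it. Errors are returned to the caller, which
// decides that the whole connection has to be torn down.
func readFrame(r *bufio.Reader) error {
	var sizeBuf [4]byte
	if _, err := io.ReadFull(r, sizeBuf[:]); err != nil {
		return err
	}
	payload := make([]byte, binary.BigEndian.Uint32(sizeBuf[:]))
	if _, err := io.ReadFull(r, payload); err != nil {
		return err
	}
	fmt.Printf("frame: %q\n", payload)
	return nil
}

// loopRead owns the reader: it is created once per connection and reused for
// every frame, mirroring the loopRead/readFrame split in this commit.
func loopRead(conn io.Reader) error {
	r := bufio.NewReader(conn)
	for {
		if err := readFrame(r); err != nil {
			// an unrecovered read error closes the transport
			return err
		}
	}
}

func main() {
	var buf bytes.Buffer
	for _, msg := range []string{"hello", "ttstream"} {
		var size [4]byte
		binary.BigEndian.PutUint32(size[:], uint32(len(msg)))
		buf.Write(size[:])
		buf.WriteString(msg)
	}
	fmt.Println("loop ended:", loopRead(&buf)) // io.EOF once the buffer is drained
}

Splitting the decode into readFrame also gives every frame type a single exit path through return err, instead of the mix of continue and early returns in the old loop.
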
2 changes: 1 addition & 1 deletion pkg/streamx/provider/ttstream/ttstream_client_test.go
@@ -270,7 +270,7 @@ func TestTTHeaderStreaming(t *testing.T) {
 	t.Logf("Client ClientStream CloseAndRecv: %v", res)
 	atomic.AddInt32(&serverStreamCount, -1)
 	waitServerStreamDone()
-	test.Assert(t, serverRecvCount == int32(round), serverRecvCount)
+	test.DeepEqual(t, serverRecvCount, int32(round))
 	test.Assert(t, serverSendCount == 1, serverSendCount)
 	testHeaderAndTrailer(t, cs)
 	cs = nil
