diff --git a/pkg/streamx/provider/ttstream/client_provier.go b/pkg/streamx/provider/ttstream/client_provier.go index 3d439b944e..0762bf30f6 100644 --- a/pkg/streamx/provider/ttstream/client_provier.go +++ b/pkg/streamx/provider/ttstream/client_provier.go @@ -19,6 +19,7 @@ package ttstream import ( "context" "runtime" + "sync/atomic" "github.com/bytedance/gopkg/cloud/metainfo" "github.com/cloudwego/gopkg/protocol/ttheader" @@ -78,24 +79,44 @@ func (c clientProvider) NewStream(ctx context.Context, ri rpcinfo.RPCInfo, callO return nil, err } - s, err := trans.newStream(ctx, method, intHeader, strHeader) + sio, err := trans.newStreamIO(ctx, method, intHeader, strHeader) if err != nil { return nil, err } - s.setRecvTimeout(rconfig.StreamRecvTimeout()) + sio.stream.setRecvTimeout(rconfig.StreamRecvTimeout()) // only client can set meta frame handler - s.setMetaFrameHandler(c.metaHandler) + sio.stream.setMetaFrameHandler(c.metaHandler) // if ctx from server side, we should cancel the stream when server handler already returned // TODO: this canceling transmit should be configurable ktx.RegisterCancelCallback(ctx, func() { - s.cancel() + sio.stream.cancel() }) - cs := newClientStream(s) + cs := newClientStream(sio.stream) + // the END of a client stream means it should send and recv trailer and not hold by user anymore + var ended uint32 + sio.setEOFCallback(func() { + // if stream is ended by both parties, put the transport back to pool + sio.stream.close() + if atomic.AddUint32(&ended, 1) == 2 { + if trans.IsActive() { + c.transPool.Put(trans) + } + err = trans.streamDelete(sio.stream.sid) + } + }) runtime.SetFinalizer(cs, func(cstream *clientStream) { - _ = cstream.close() - c.transPool.Put(trans) + // it's safe to call CloseSend twice + // we do repeated CloseSend here to ensure stream can be closed normally + _ = cstream.CloseSend(ctx) + // only delete stream when clientStream be finalized + if atomic.AddUint32(&ended, 1) == 2 { + if trans.IsActive() { + c.transPool.Put(trans) + } + err = trans.streamDelete(sio.stream.sid) + } }) return cs, err } diff --git a/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go b/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go index cf3724332f..9e55aef78b 100644 --- a/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go +++ b/pkg/streamx/provider/ttstream/client_trans_pool_longconn.go @@ -50,9 +50,15 @@ type longConnTransPool struct { } func (c *longConnTransPool) Get(sinfo *serviceinfo.ServiceInfo, network string, addr string) (trans *transport, err error) { - o := c.transPool.Pop(addr) - if o != nil { - return o.(*transport), nil + for { + o := c.transPool.Pop(addr) + if o == nil { + break + } + trans = o.(*transport) + if trans.IsActive() { + return trans, nil + } } // create new connection @@ -60,13 +66,8 @@ func (c *longConnTransPool) Get(sinfo *serviceinfo.ServiceInfo, network string, if err != nil { return nil, err } - // create new transport trans = newTransport(clientTransport, sinfo, conn) - _ = conn.AddCloseCallback(func(connection netpoll.Connection) error { - _ = trans.Close() - return nil - }) - runtime.SetFinalizer(trans, func(t *transport) { t.Close() }) + // create new transport return trans, nil } diff --git a/pkg/streamx/provider/ttstream/container/object_pool.go b/pkg/streamx/provider/ttstream/container/object_pool.go index c2dd4090dd..26fdc26978 100644 --- a/pkg/streamx/provider/ttstream/container/object_pool.go +++ b/pkg/streamx/provider/ttstream/container/object_pool.go @@ -68,11 +68,16 @@ func (s *ObjectPool) cleaning() { now := time.Now() s.L.Lock() - objSize := len(s.objects) + // update cleanInternal + objSize := 0 + for _, stk := range s.objects { + objSize += stk.Size() + } cleanInternal = time.Second + time.Duration(objSize)*time.Millisecond*10 if cleanInternal > time.Second*10 { cleanInternal = time.Second * 10 } + // clean objects for key, stk := range s.objects { deleted := 0 var oldest *time.Time @@ -97,7 +102,7 @@ func (s *ObjectPool) cleaning() { return true, true }) if oldest != nil { - klog.Infof("object[%s] pool delete %d objects", key, deleted) + klog.Infof("object[%s] pool deleted %d objects, oldest=%s", key, deleted, oldest.String()) } } s.L.Unlock() diff --git a/pkg/streamx/provider/ttstream/container/stack.go b/pkg/streamx/provider/ttstream/container/stack.go index d076257331..7083fc27e0 100644 --- a/pkg/streamx/provider/ttstream/container/stack.go +++ b/pkg/streamx/provider/ttstream/container/stack.go @@ -43,6 +43,7 @@ func (s *Stack[ValueType]) Size() (size int) { func (s *Stack[ValueType]) RangeDelete(checking func(v ValueType) (deleteNode bool, continueRange bool)) { // Stop the world! s.L.Lock() + // range from the stack bottom(oldest item) node := s.head deleteNode := false continueRange := true @@ -56,10 +57,7 @@ func (s *Stack[ValueType]) RangeDelete(checking func(v ValueType) (deleteNode bo last := node.last next := node.next // modify last node - if last == nil { - // cur node is head node - s.head = next - } else { + if last != nil { // change last next ptr last.next = next } @@ -67,7 +65,11 @@ func (s *Stack[ValueType]) RangeDelete(checking func(v ValueType) (deleteNode bo if next != nil { next.last = last } - if node == s.tail { + // modify link list + if s.head == node { + s.head = next + } + if s.tail == node { s.tail = last } node = node.next @@ -76,24 +78,31 @@ func (s *Stack[ValueType]) RangeDelete(checking func(v ValueType) (deleteNode bo s.L.Unlock() } -func (s *Stack[ValueType]) Pop() (value ValueType, ok bool) { - var node *doubleLinkNode[ValueType] - s.L.Lock() +func (s *Stack[ValueType]) pop() (node *doubleLinkNode[ValueType]) { if s.tail == nil { - s.L.Unlock() - return value, false + return nil } node = s.tail if node.last == nil { - // first node + // if node is the only node in the list, clear the whole linklist s.head = nil s.tail = nil } else { - node.last.next = nil + // if node is not the only node in the list, only modify the list's tail s.tail = node.last + s.tail.next = nil } s.size-- + return node +} + +func (s *Stack[ValueType]) Pop() (value ValueType, ok bool) { + s.L.Lock() + node := s.pop() s.L.Unlock() + if node == nil { + return value, false + } value = node.val node.reset() @@ -101,23 +110,32 @@ func (s *Stack[ValueType]) Pop() (value ValueType, ok bool) { return value, true } -func (s *Stack[ValueType]) PopBottom() (value ValueType, ok bool) { - var node *doubleLinkNode[ValueType] - s.L.Lock() +func (s *Stack[ValueType]) popBottom() (node *doubleLinkNode[ValueType]) { if s.head == nil { - s.L.Unlock() - return value, false + return nil } node = s.head - s.head = s.head.next - if s.head != nil { - s.head.last = nil - } - if s.tail == node { + if node.next == nil { + // if node is the only node in the list, clear the whole linklist + s.head = nil s.tail = nil + } else { + // if node is not the only node in the list, only modify the list's head + s.head = s.head.next + s.head.last = nil } s.size-- + return node +} + +func (s *Stack[ValueType]) PopBottom() (value ValueType, ok bool) { + s.L.Lock() + node := s.popBottom() s.L.Unlock() + if node == nil { + return value, false + } + value = node.val node.reset() s.nodePool.Put(node) @@ -142,6 +160,7 @@ func (s *Stack[ValueType]) Push(value ValueType) { s.head = node s.tail = node } else { + // not first node node.last = s.tail s.tail.next = node s.tail = node diff --git a/pkg/streamx/provider/ttstream/metadata.go b/pkg/streamx/provider/ttstream/metadata.go index 81f15ebcd2..fef0b03789 100644 --- a/pkg/streamx/provider/ttstream/metadata.go +++ b/pkg/streamx/provider/ttstream/metadata.go @@ -22,7 +22,10 @@ import ( "github.com/cloudwego/kitex/pkg/streamx" ) -var ErrInvalidStreamKind = errors.New("invalid stream kind") +var ( + ErrInvalidStreamKind = errors.New("invalid stream kind") + ErrClosedStream = errors.New("stream is closed") +) // only for meta frame handler type IntHeader map[uint16]string diff --git a/pkg/streamx/provider/ttstream/stream.go b/pkg/streamx/provider/ttstream/stream.go index aee917318c..7c5eaf841e 100644 --- a/pkg/streamx/provider/ttstream/stream.go +++ b/pkg/streamx/provider/ttstream/stream.go @@ -39,15 +39,16 @@ var ( _ StreamMeta = (*stream)(nil) ) -func newStream(ctx context.Context, trans *transport, mode streamx.StreamingMode, smeta streamFrame) (s *stream) { - s = new(stream) +func newStream(trans *transport, mode streamx.StreamingMode, smeta streamFrame) *stream { + s := new(stream) s.streamFrame = smeta s.trans = trans s.mode = mode - s.headerSig = make(chan struct{}) - s.trailerSig = make(chan struct{}) + s.wheader = make(streamx.Header) + s.wtrailer = make(streamx.Trailer) + s.headerSig = make(chan int32, 1) + s.trailerSig = make(chan int32, 1) s.StreamMeta = newStreamMeta() - trans.storeStreamIO(ctx, s) return s } @@ -58,17 +59,23 @@ type streamFrame struct { trailer streamx.Trailer } +const ( + streamSigNone int32 = 0 + streamSigActive int32 = 1 + streamSigInactive int32 = -1 +) + type stream struct { streamFrame trans *transport mode streamx.StreamingMode - wheader streamx.Header - wtrailer streamx.Trailer + wheader streamx.Header // wheader == nil means it already be sent + wtrailer streamx.Trailer // wtrailer == nil means it already be sent selfEOF int32 peerEOF int32 - headerSig chan struct{} - trailerSig chan struct{} - err error + headerSig chan int32 + trailerSig chan int32 + err error StreamMeta metaHandler MetaFrameHandler @@ -92,14 +99,12 @@ func (s *stream) Method() string { func (s *stream) close() { select { - case <-s.headerSig: + case s.headerSig <- streamSigInactive: default: - close(s.headerSig) } select { - case <-s.trailerSig: + case s.trailerSig <- streamSigInactive: default: - close(s.trailerSig) } } @@ -117,10 +122,9 @@ func (s *stream) readMetaFrame(intHeader IntHeader, header streamx.Header, paylo func (s *stream) readHeader(hd streamx.Header) (err error) { s.header = hd select { - case <-s.headerSig: - return errors.New("already set header") + case s.headerSig <- streamSigActive: default: - close(s.headerSig) + return fmt.Errorf("stream[%d] already set header", s.sid) } klog.Debugf("stream[%s] read header: %v", s.method, hd) return nil @@ -128,23 +132,29 @@ func (s *stream) readHeader(hd streamx.Header) (err error) { // setHeader use the hd as the underlying header func (s *stream) setHeader(hd streamx.Header) { - s.wheader = hd + if hd != nil { + s.wheader = hd + } return } // writeHeader copy kvs into s.wheader -func (s *stream) writeHeader(hd streamx.Header) { +func (s *stream) writeHeader(hd streamx.Header) error { if s.wheader == nil { - s.wheader = make(streamx.Header) + return fmt.Errorf("stream header already sent") } for k, v := range hd { s.wheader[k] = v } + return nil } func (s *stream) sendHeader() (err error) { wheader := s.wheader s.wheader = nil + if wheader == nil { + return fmt.Errorf("stream header already sent") + } err = s.trans.streamSendHeader(s.sid, s.method, wheader) return err } @@ -153,7 +163,7 @@ func (s *stream) sendHeader() (err error) { // readTrailer by server: unblock recv function and return EOF if no unread frame func (s *stream) readTrailerFrame(fr *Frame) (err error) { if !atomic.CompareAndSwapInt32(&s.peerEOF, 0, 1) { - return nil + return fmt.Errorf("stream read a unexcept trailer") } if len(fr.payload) > 0 { @@ -169,10 +179,14 @@ func (s *stream) readTrailerFrame(fr *Frame) (err error) { } s.trailer = fr.trailer select { - case <-s.trailerSig: + case s.trailerSig <- streamSigActive: + default: return errors.New("already set trailer") + } + select { + case s.headerSig <- streamSigNone: + // if trailer arrived, we should return unblock stream.Header() default: - close(s.trailerSig) } klog.Debugf("stream[%d] recv trailer: %v, err: %v", s.sid, s.trailer, s.err) @@ -181,7 +195,7 @@ func (s *stream) readTrailerFrame(fr *Frame) (err error) { func (s *stream) writeTrailer(tl streamx.Trailer) (err error) { if s.wtrailer == nil { - s.wtrailer = make(streamx.Trailer) + return fmt.Errorf("stream trailer already sent") } for k, v := range tl { s.wtrailer[k] = v @@ -211,8 +225,13 @@ func (s *stream) sendTrailer(ctx context.Context, ex tException) (err error) { if !atomic.CompareAndSwapInt32(&s.selfEOF, 0, 1) { return nil } - klog.Debugf("stream[%d] send trialer", s.sid) - return s.trans.streamCloseSend(s.sid, s.method, s.wtrailer, ex) + wtrailer := s.wtrailer + s.wtrailer = nil + if wtrailer == nil { + return fmt.Errorf("stream trailer already sent") + } + klog.Debugf("transport[%d]-stream[%d] send trialer", s.trans.kind, s.sid) + return s.trans.streamCloseSend(s.sid, s.method, wtrailer, ex) } func (s *stream) finished() bool { @@ -262,18 +281,6 @@ func (s *clientStream) CloseSend(ctx context.Context) error { return s.sendTrailer(ctx, nil) } -// after close stream cannot be access again -func (s *clientStream) close() error { - // client should CloseSend first and then close stream - err := s.sendTrailer(context.Background(), nil) - if err != nil { - return err - } - err = s.trans.streamClose(s.sid) - s.stream.close() - return err -} - func newServerStream(s *stream) streamx.ServerStream { ss := &serverStream{stream: s} return ss @@ -305,7 +312,7 @@ func (s *serverStream) close(ex tException) error { if err != nil { return err } - err = s.trans.streamClose(s.sid) + err = s.trans.streamDelete(s.sid) s.stream.close() return err } diff --git a/pkg/streamx/provider/ttstream/stream_header_trailer.go b/pkg/streamx/provider/ttstream/stream_header_trailer.go index e773e9123a..beedb24f41 100644 --- a/pkg/streamx/provider/ttstream/stream_header_trailer.go +++ b/pkg/streamx/provider/ttstream/stream_header_trailer.go @@ -16,28 +16,49 @@ package ttstream -import "github.com/cloudwego/kitex/pkg/streamx" +import ( + "errors" + + "github.com/cloudwego/kitex/pkg/streamx" +) var _ ClientStreamMeta = (*clientStream)(nil) var _ ServerStreamMeta = (*serverStream)(nil) func (s *clientStream) Header() (streamx.Header, error) { - <-s.headerSig - return s.header, nil + sig := <-s.headerSig + switch sig { + case streamSigActive: + return s.header, nil + case streamSigNone: + return make(streamx.Header), nil + case streamSigInactive: + return nil, ErrClosedStream + } + return nil, errors.New("invalid stream signal") } func (s *clientStream) Trailer() (streamx.Trailer, error) { - <-s.trailerSig - return s.trailer, nil + sig := <-s.trailerSig + switch sig { + case streamSigActive: + return s.trailer, nil + case streamSigNone: + return make(streamx.Trailer), nil + case streamSigInactive: + return nil, ErrClosedStream + } + return nil, errors.New("invalid stream signal") } func (s *serverStream) SetHeader(hd streamx.Header) error { - s.writeHeader(hd) - return nil + return s.writeHeader(hd) } func (s *serverStream) SendHeader(hd streamx.Header) error { - s.writeHeader(hd) + if err := s.writeHeader(hd); err != nil { + return err + } return s.stream.sendHeader() } diff --git a/pkg/streamx/provider/ttstream/stream_io.go b/pkg/streamx/provider/ttstream/stream_io.go index d82e27058b..868425dd4a 100644 --- a/pkg/streamx/provider/ttstream/stream_io.go +++ b/pkg/streamx/provider/ttstream/stream_io.go @@ -20,6 +20,7 @@ import ( "context" "errors" "io" + "sync/atomic" "github.com/cloudwego/kitex/pkg/klog" "github.com/cloudwego/kitex/pkg/streamx/provider/ttstream/container" @@ -29,9 +30,14 @@ type streamIO struct { ctx context.Context trigger chan struct{} stream *stream - fpipe *container.Pipe[*Frame] - fcache [1]*Frame - err error + // eofFlag == 2 when both parties send trailers + eofFlag int32 + // eofCallback will be called when eofFlag == 2 + // eofCallback will not be called if stream is not be ended in a normal way + eofCallback func() + fpipe *container.Pipe[*Frame] + fcache [1]*Frame + err error } func newStreamIO(ctx context.Context, s *stream) *streamIO { @@ -43,6 +49,10 @@ func newStreamIO(ctx context.Context, s *stream) *streamIO { return sio } +func (s *streamIO) setEOFCallback(f func()) { + s.eofCallback = f +} + func (s *streamIO) input(ctx context.Context, f *Frame) { err := s.fpipe.Write(ctx, f) if err != nil { @@ -76,9 +86,15 @@ func (s *streamIO) output(ctx context.Context) (f *Frame, err error) { func (s *streamIO) closeRecv() { s.fpipe.Close() + if atomic.AddInt32(&s.eofFlag, 1) == 2 && s.eofCallback != nil { + s.eofCallback() + } } func (s *streamIO) closeSend() { + if atomic.AddInt32(&s.eofFlag, 1) == 2 && s.eofCallback != nil { + s.eofCallback() + } } func (s *streamIO) cancel() { diff --git a/pkg/streamx/provider/ttstream/transport.go b/pkg/streamx/provider/ttstream/transport.go index 2e0eb6e23a..ab7825dfec 100644 --- a/pkg/streamx/provider/ttstream/transport.go +++ b/pkg/streamx/provider/ttstream/transport.go @@ -44,7 +44,9 @@ const ( frameChanSize = 32 ) -var transIgnoreError = errors.Join(netpoll.ErrEOF, io.EOF, netpoll.ErrConnClosed) +func isIgnoreError(err error) bool { + return errors.Is(err, netpoll.ErrEOF) || errors.Is(err, io.EOF) || errors.Is(err, netpoll.ErrConnClosed) +} type transport struct { kind int32 @@ -60,7 +62,9 @@ type transport struct { } func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Connection) *transport { - _ = conn.SetDeadline(time.Now().Add(time.Hour)) + // stream max idle session is 10 minutes. + // TODO: let it configurable + _ = conn.SetReadTimeout(time.Minute * 10) t := &transport{ kind: kind, sinfo: sinfo, @@ -74,21 +78,21 @@ func newTransport(kind int32, sinfo *serviceinfo.ServiceInfo, conn netpoll.Conne go func() { err := t.loopRead() if err != nil { - if !errors.Is(err, transIgnoreError) { + if !isIgnoreError(err) { klog.Warnf("transport[%d] loop read err: %v", t.kind, err) } + // if connection is closed by peer, loop read should return ErrConnClosed error, + // so we should close transport here _ = t.Close() } }() go func() { err := t.loopWrite() if err != nil { - if !errors.Is(err, transIgnoreError) { + if !isIgnoreError(err) { klog.Warnf("transport[%d] loop write err: %v", t.kind, err) } _ = t.Close() - // because loopWrite function return, we should close conn actively - _ = t.conn.Close() } }() return t @@ -106,14 +110,21 @@ func (t *transport) Close() (err error) { t.spipe.Close() t.streams.Range(func(key, value any) bool { sio := value.(*streamIO) - _ = t.streamClose(sio.stream.sid) + sio.stream.close() + _ = t.streamDelete(sio.stream.sid) return true }) return err } -func (t *transport) storeStreamIO(ctx context.Context, s *stream) { - t.streams.Store(s.sid, newStreamIO(ctx, s)) +func (t *transport) IsActive() bool { + return atomic.LoadInt32(&t.closedFlag) == 0 && t.conn.IsActive() +} + +func (t *transport) storeStreamIO(ctx context.Context, s *stream) *streamIO { + sio := newStreamIO(ctx, s) + t.streams.Store(s.sid, sio) + return sio } func (t *transport) loadStreamIO(sid int32) (sio *streamIO, ok bool) { @@ -126,6 +137,10 @@ func (t *transport) loadStreamIO(sid int32) (sio *streamIO, ok bool) { } func (t *transport) loopRead() error { + addr := t.conn.RemoteAddr().String() + if t.kind == clientTransport { + addr = t.conn.LocalAddr().String() + } for { // decode frame sizeBuf, err := t.conn.Reader().Peek(4) @@ -142,13 +157,13 @@ func (t *transport) loopRead() error { if err != nil { return err } - klog.Debugf("transport[%d] DecodeFrame: fr=%v", t.kind, fr) + klog.Debugf("transport[%d-%s] DecodeFrame: fr=%v", t.kind, addr, fr) switch fr.typ { case metaFrameType: sio, ok := t.loadStreamIO(fr.sid) if !ok { - klog.Errorf("transport[%d] read a unknown stream meta: sid=%d", t.kind, fr.sid) + klog.Errorf("transport[%d-%s] read a unknown stream meta: sid=%d", t.kind, addr, fr.sid) continue } err = sio.stream.readMetaFrame(fr.meta, fr.header, fr.payload) @@ -160,14 +175,15 @@ func (t *transport) loopRead() error { case serverTransport: // Header Frame: server recv a new stream smode := t.sinfo.MethodInfo(fr.method).StreamingMode() - s := newStream(context.Background(), t, smode, fr.streamFrame) - klog.Debugf("transport[%d] read a new stream: sid=%d", t.kind, s.sid) + s := newStream(t, smode, fr.streamFrame) + t.storeStreamIO(context.Background(), s) t.spipe.Write(context.Background(), s) case clientTransport: // Header Frame: client recv header sio, ok := t.loadStreamIO(fr.sid) if !ok { - klog.Errorf("transport[%d] read a unknown stream header: sid=%d", t.kind, fr.sid) + klog.Errorf("transport[%d-%s] read a unknown stream header: sid=%d header=%v", + t.kind, addr, fr.sid, fr.header) continue } err = sio.stream.readHeader(fr.header) @@ -179,7 +195,7 @@ func (t *transport) loopRead() error { // Data Frame: decode and distribute data sio, ok := t.loadStreamIO(fr.sid) if !ok { - klog.Errorf("transport[%d] read a unknown stream data: sid=%d", t.kind, fr.sid) + klog.Errorf("transport[%d-%s] read a unknown stream data: sid=%d", t.kind, addr, fr.sid) continue } sio.input(context.Background(), fr) @@ -190,7 +206,8 @@ func (t *transport) loopRead() error { // client recv an unknown trailer is in exception, // because the client stream may already be GCed, // but the connection is still active so peer server can send a trailer - klog.Debugf("transport[%d] read a unknown stream trailer: sid=%d", t.kind, fr.sid) + klog.Errorf("transport[%d-%s] read a unknown stream trailer: sid=%d trailer=%v", + t.kind, addr, fr.sid, fr.trailer) continue } if err = sio.stream.readTrailerFrame(fr); err != nil { @@ -330,9 +347,10 @@ func (t *transport) streamCancel(s *stream) (err error) { return nil } -func (t *transport) streamClose(sid int32) (err error) { +func (t *transport) streamDelete(sid int32) (err error) { // remove stream from transport - if _, ok := t.streams.LoadAndDelete(sid); !ok { + _, ok := t.streams.LoadAndDelete(sid) + if !ok { return nil } atomic.AddInt32(&t.streamingFlag, -1) @@ -345,11 +363,11 @@ func (t *transport) IsStreaming() bool { var clientStreamID int32 -// newStream create new stream on current connection +// newStreamIO create new stream on current connection // it's typically used by client side -// newStream is concurrency safe -func (t *transport) newStream( - ctx context.Context, method string, intHeader IntHeader, strHeader streamx.Header) (*stream, error) { +// newStreamIO is concurrency safe +func (t *transport) newStreamIO( + ctx context.Context, method string, intHeader IntHeader, strHeader streamx.Header) (*streamIO, error) { if t.kind != clientTransport { return nil, fmt.Errorf("transport already be used as other kind") } @@ -364,9 +382,10 @@ func (t *transport) newStream( if err != nil { return nil, err } - s := newStream(ctx, t, smode, streamFrame{sid: sid, method: method}) + s := newStream(t, smode, streamFrame{sid: sid, method: method}) + sio := t.storeStreamIO(ctx, s) atomic.AddInt32(&t.streamingFlag, 1) - return s, nil + return sio, nil } // readStream wait for a new incoming stream on current connection diff --git a/pkg/streamx/provider/ttstream/transport_test.go b/pkg/streamx/provider/ttstream/transport_test.go index 3a1d93bd8d..729d49b121 100644 --- a/pkg/streamx/provider/ttstream/transport_test.go +++ b/pkg/streamx/provider/ttstream/transport_test.go @@ -145,10 +145,10 @@ func TestTransport(t *testing.T) { defer wg.Done() // send header - s, err := trans.newStream(ctx, method, IntHeader{}, map[string]string{}) + sio, err := trans.newStreamIO(ctx, method, IntHeader{}, map[string]string{}) test.Assert(t, err == nil, err) - cs := newClientStream(s) + cs := newClientStream(sio.stream) t.Logf("client stream[%d] created", sid) // recv header diff --git a/pkg/streamx/provider/ttstream/ttstream_client_test.go b/pkg/streamx/provider/ttstream/ttstream_client_test.go index 2160a49192..10c1e96ede 100644 --- a/pkg/streamx/provider/ttstream/ttstream_client_test.go +++ b/pkg/streamx/provider/ttstream/ttstream_client_test.go @@ -300,7 +300,7 @@ func TestTTHeaderStreaming(t *testing.T) { atomic.AddInt32(&serverStreamCount, -1) waitServerStreamDone() test.Assert(t, serverRecvCount == 1, serverRecvCount) - test.Assert(t, serverSendCount == int32(received), serverSendCount) + test.Assert(t, serverSendCount == int32(received), serverSendCount, received) testHeaderAndTrailer(t, ss) ss = nil serverRecvCount = 0