diff --git a/pkg/remote/trans/netpoll/bytebuf.go b/pkg/remote/trans/netpoll/bytebuf.go index 43825d3da2..fd644902da 100644 --- a/pkg/remote/trans/netpoll/bytebuf.go +++ b/pkg/remote/trans/netpoll/bytebuf.go @@ -18,6 +18,7 @@ package netpoll import ( "errors" + "io" "sync" "github.com/cloudwego/netpoll" @@ -35,6 +36,7 @@ func init() { func NewReaderByteBuffer(r netpoll.Reader) remote.ByteBuffer { bytebuf := bytebufPool.Get().(*netpollByteBuffer) bytebuf.reader = r + bytebuf.rd, _ = r.(io.Reader) bytebuf.status = remote.BitReadable bytebuf.readSize = 0 return bytebuf @@ -53,6 +55,7 @@ func NewReaderWriterByteBuffer(rw netpoll.ReadWriter) remote.ByteBuffer { bytebuf := bytebufPool.Get().(*netpollByteBuffer) bytebuf.writer = rw bytebuf.reader = rw + bytebuf.rd, _ = rw.(io.Reader) bytebuf.status = remote.BitWritable | remote.BitReadable return bytebuf } @@ -66,6 +69,8 @@ type netpollByteBuffer struct { reader netpoll.Reader status int readSize int + + rd io.Reader // from reader for Read } var _ remote.ByteBuffer = &netpollByteBuffer{} @@ -107,11 +112,32 @@ func (b *netpollByteBuffer) ReadableLen() (n int) { } // Read implement io.Reader -func (b *netpollByteBuffer) Read(p []byte) (n int, err error) { +func (b *netpollByteBuffer) Read(p []byte) (int, error) { if b.status&remote.BitReadable == 0 { return -1, errors.New("unreadable buffer, cannot support Read") } - rb, err := b.reader.Next(len(p)) + if b.rd != nil { + // use io.Reader if implemented, it works for netpoll connection. + // but for netpoll.Reader, it doesn't guarantee that + // the underlying implementation has Read method of io.Reader + return b.rd.Read(p) + } + + // make sure we have at least one byte to read, + // or Next call may block till timeout + m := b.reader.Len() + if m == 0 { + _, err := b.reader.Peek(1) + if err != nil { + return 0, err + } + m = b.reader.Len() // must >= 1 + } + n := len(p) + if n > m { + n = m + } + rb, err := b.reader.Next(n) b.readSize += len(rb) return copy(p, rb), err } diff --git a/pkg/remote/trans/netpoll/bytebuf_test.go b/pkg/remote/trans/netpoll/bytebuf_test.go index 1ecd44a2cf..dab176784d 100644 --- a/pkg/remote/trans/netpoll/bytebuf_test.go +++ b/pkg/remote/trans/netpoll/bytebuf_test.go @@ -55,6 +55,18 @@ func TestByteBuffer(t *testing.T) { test.Assert(t, err == nil) } +func TestByteBuffer_Read(t *testing.T) { + const teststr = "testing" + buf := &bytes.Buffer{} + buf.WriteString(teststr) + r := NewReaderByteBuffer(netpoll.NewReader(buf)) + b := make([]byte, len(teststr)+1) + n, err := r.Read(b) + test.Assert(t, err == nil, err) + test.Assert(t, n == len(teststr), n) + test.Assert(t, string(b[:n]) == teststr) +} + // TestWriterBuffer test writerbytebufferr return writedirect err func TestWriterBuffer(t *testing.T) { // 1. prepare mock data