From c2d4d48833e6a62647af553861ffd64ca2a5d209 Mon Sep 17 00:00:00 2001
From: Kyle Xiao
Date: Wed, 16 Oct 2024 18:00:50 +0800
Subject: [PATCH] perf(client): new rpctimeout impl
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

goos: darwin
goarch: arm64
pkg: github.com/cloudwego/kitex/client
                │  ./old.txt  │             ./new.txt              │
                │   sec/op    │   sec/op     vs base               │
RPCTimeoutMW-12   1167.5n ± 1%   797.1n ± 1%  -31.73% (p=0.002 n=6)

                │ ./old.txt  │             ./new.txt              │
                │    B/op    │    B/op     vs base                │
RPCTimeoutMW-12   560.0 ± 0%   176.0 ± 0%  -68.57% (p=0.002 n=6)

                │ ./old.txt  │             ./new.txt              │
                │ allocs/op  │ allocs/op   vs base                │
RPCTimeoutMW-12   9.000 ± 0%   2.000 ± 0%  -77.78% (p=0.002 n=6)
---
 client/rpctimeout.go        |  75 ++++++-----------
 client/rpctimeout_test.go   |  14 ++++
 internal/wpool/context.go   |  81 ++++++++++++++++++
 internal/wpool/pool.go      | 149 +++++++++++++++++----------------
 internal/wpool/pool_test.go |  44 +++++-----
 internal/wpool/task.go      | 161 ++++++++++++++++++++++++++++++++++++
 internal/wpool/task_test.go | 101 ++++++++++++++++++++++
 7 files changed, 488 insertions(+), 137 deletions(-)
 create mode 100644 internal/wpool/context.go
 create mode 100644 internal/wpool/task.go
 create mode 100644 internal/wpool/task_test.go

diff --git a/client/rpctimeout.go b/client/rpctimeout.go
index 4a8cb588ba..c60ada4f6a 100644
--- a/client/rpctimeout.go
+++ b/client/rpctimeout.go
@@ -96,68 +96,47 @@ func rpcTimeoutMW(mwCtx context.Context) endpoint.Middleware {
 	}
 
 	return func(next endpoint.Endpoint) endpoint.Endpoint {
+		backgroundEP := func(ctx context.Context, request, response interface{}) error {
+			err := next(ctx, request, response)
+			if err != nil && ctx.Err() != nil &&
+				!kerrors.IsTimeoutError(err) && !errors.Is(err, kerrors.ErrRPCFinish) {
+				ri := rpcinfo.GetRPCInfo(ctx)
+				// The error occurs after the wait goroutine has returned (RPCTimeout happened);
+				// we should log it for troubleshooting, or it will be discarded.
+				// ErrRPCTimeout and ErrRPCFinish can be ignored:
+				// ErrRPCTimeout: the same as the outer timeout; only non-timeout errors matter here.
+				// ErrRPCFinish: happens in the retry scene when a previous call returns first.
+				var errMsg string
+				if ri.To().Address() != nil {
+					errMsg = fmt.Sprintf("KITEX: to_service=%s method=%s addr=%s error=%s",
+						ri.To().ServiceName(), ri.To().Method(), ri.To().Address(), err.Error())
+				} else {
+					errMsg = fmt.Sprintf("KITEX: to_service=%s method=%s error=%s",
+						ri.To().ServiceName(), ri.To().Method(), err.Error())
+				}
+				klog.CtxErrorf(ctx, "%s", errMsg)
+			}
+			return err
+		}
 		return func(ctx context.Context, request, response interface{}) error {
 			ri := rpcinfo.GetRPCInfo(ctx)
 			if ri.Config().InteractionMode() == rpcinfo.Streaming {
 				return next(ctx, request, response)
 			}
-
 			tm := ri.Config().RPCTimeout()
 			if tm > 0 {
 				tm += moreTimeout
-				var cancel context.CancelFunc
-				ctx, cancel = context.WithTimeout(ctx, tm)
-				defer cancel()
 			}
-			// Fast path for ctx without timeout
-			if ctx.Done() == nil {
+			// fast path if no timeout
+			if ctx.Done() == nil && tm <= 0 {
 				return next(ctx, request, response)
 			}
-
-			var err error
 			start := time.Now()
-			done := make(chan error, 1)
-			workerPool.GoCtx(ctx, func() {
-				defer func() {
-					if panicInfo := recover(); panicInfo != nil {
-						e := rpcinfo.ClientPanicToErr(ctx, panicInfo, ri, true)
-						done <- e
-					}
-					if err == nil || !errors.Is(err, kerrors.ErrRPCFinish) {
-						// Don't regards ErrRPCFinish as normal error, it happens in retry scene,
-						// ErrRPCFinish means previous call returns first but is decoding.
- close(done) - } - }() - err = next(ctx, request, response) - if err != nil && ctx.Err() != nil && - !kerrors.IsTimeoutError(err) && !errors.Is(err, kerrors.ErrRPCFinish) { - // error occurs after the wait goroutine returns(RPCTimeout happens), - // we should log this error for troubleshooting, or it will be discarded. - // but ErrRPCTimeout and ErrRPCFinish can be ignored: - // ErrRPCTimeout: it is same with outer timeout, here only care about non-timeout err. - // ErrRPCFinish: it happens in retry scene, previous call returns first. - var errMsg string - if ri.To().Address() != nil { - errMsg = fmt.Sprintf("KITEX: to_service=%s method=%s addr=%s error=%s", - ri.To().ServiceName(), ri.To().Method(), ri.To().Address(), err.Error()) - } else { - errMsg = fmt.Sprintf("KITEX: to_service=%s method=%s error=%s", - ri.To().ServiceName(), ri.To().Method(), err.Error()) - } - klog.CtxErrorf(ctx, "%s", errMsg) - } - }) - - select { - case panicErr := <-done: - if panicErr != nil { - return panicErr - } - return err - case <-ctx.Done(): + ctx, err := workerPool.RunTask(ctx, tm, request, response, backgroundEP) + if ctx.Err() != nil { // context.Canceled or context.DeadlineExceeded? return makeTimeoutErr(ctx, start, tm) } + return err } } } diff --git a/client/rpctimeout_test.go b/client/rpctimeout_test.go index f0aacce011..c1e8ff54ec 100644 --- a/client/rpctimeout_test.go +++ b/client/rpctimeout_test.go @@ -313,3 +313,17 @@ func Test_isBusinessTimeout(t *testing.T) { }) } } + +func BenchmarkRPCTimeoutMW(b *testing.B) { + s := rpcinfo.NewEndpointInfo("mockService", "mockMethod", nil, nil) + c := rpcinfo.NewRPCConfig() + r := rpcinfo.NewRPCInfo(nil, s, nil, c, rpcinfo.NewRPCStats()) + m := rpcinfo.AsMutableRPCConfig(c) + m.SetRPCTimeout(20 * time.Millisecond) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), r) + mw := rpcTimeoutMW(ctx) + ep := mw(func(ctx context.Context, req, resp interface{}) (err error) { return nil }) + for i := 0; i < b.N; i++ { + ep(ctx, b, b) + } +} diff --git a/internal/wpool/context.go b/internal/wpool/context.go new file mode 100644 index 0000000000..94e16db5c8 --- /dev/null +++ b/internal/wpool/context.go @@ -0,0 +1,81 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */

+package wpool
+
+import (
+	"context"
+	"errors"
+	"sync"
+	"time"
+)
+
+type taskContext struct {
+	context.Context
+
+	dl time.Time
+	ch chan struct{}
+
+	mu  sync.Mutex
+	err error
+}
+
+func newTaskContext(ctx context.Context, timeout time.Duration) *taskContext {
+	ret := &taskContext{Context: ctx, ch: make(chan struct{})}
+	deadline, ok := ctx.Deadline()
+	if ok {
+		ret.dl = deadline
+	}
+	if timeout != 0 {
+		dl := time.Now().Add(timeout)
+		if ret.dl.IsZero() || dl.Before(ret.dl) {
+			// The new deadline is sooner than the ctx one
+			ret.dl = dl
+		}
+	}
+	return ret
+}
+
+func (p *taskContext) Deadline() (deadline time.Time, ok bool) {
+	return p.dl, !p.dl.IsZero()
+}
+
+func (p *taskContext) Done() <-chan struct{} {
+	return p.ch
+}
+
+// errTaskDone is only used internally;
+// it will never be returned to the end user.
+var errTaskDone = errors.New("task done")
+
+func (p *taskContext) Err() error {
+	p.mu.Lock()
+	err := p.err
+	p.mu.Unlock()
+	if err == nil || err == errTaskDone {
+		return p.Context.Err()
+	}
+	return err
+}
+
+func (p *taskContext) Cancel(err error) {
+	p.mu.Lock()
+	if err != nil && p.err == nil {
+		p.err = err
+		close(p.ch)
+	}
+	p.mu.Unlock()
+}
diff --git a/internal/wpool/pool.go b/internal/wpool/pool.go
index da04ec0bda..f0f5685f51 100644
--- a/internal/wpool/pool.go
+++ b/internal/wpool/pool.go
@@ -23,36 +23,32 @@ package wpool
 
 import (
 	"context"
-	"runtime/debug"
+	"sync"
 	"sync/atomic"
 	"time"
 
-	"github.com/cloudwego/kitex/pkg/klog"
-	"github.com/cloudwego/kitex/pkg/profiler"
+	"github.com/cloudwego/kitex/pkg/endpoint"
 )
 
-// Task is the function that the worker will execute.
-type Task struct {
-	ctx context.Context
-	f   func()
-}
-
-// Pool is a worker pool bind with some idle goroutines.
+// Pool is a worker pool for tasks with timeouts.
 type Pool struct {
 	size  int32
-	tasks chan Task
+	tasks chan *task
 
 	// maxIdle is the number of the max idle workers in the pool.
 	// if maxIdle too small, the pool works like a native 'go func()'.
 	maxIdle int32
 	// maxIdleTime is the max idle time that the worker will wait for the new task.
 	maxIdleTime time.Duration
+
+	mu     sync.Mutex
+	ticker chan struct{}
 }
 
 // New creates a new worker pool.
 func New(maxIdle int, maxIdleTime time.Duration) *Pool {
 	return &Pool{
-		tasks:       make(chan Task),
+		tasks:       make(chan *task),
 		maxIdle:     int32(maxIdle),
 		maxIdleTime: maxIdleTime,
 	}
@@ -63,69 +59,84 @@ func (p *Pool) Size() int32 {
 	return atomic.LoadInt32(&p.size)
 }
 
-// Go creates/reuses a worker to run task.
-func (p *Pool) Go(f func()) {
-	p.GoCtx(context.Background(), f)
-}
-
-// GoCtx creates/reuses a worker to run task.
-func (p *Pool) GoCtx(ctx context.Context, f func()) {
-	t := Task{ctx: ctx, f: f}
-	select {
-	case p.tasks <- t:
-		// reuse exist worker
-		return
-	default:
-	}
+func (p *Pool) createTicker() {
+	p.mu.Lock()
+	defer p.mu.Unlock()
 
-	// single shot if p.size > p.maxIdle
-	if atomic.AddInt32(&p.size, 1) > p.maxIdle {
-		go func(t Task) {
-			defer func() {
-				if r := recover(); r != nil {
-					klog.Errorf("panic in wpool: error=%v: stack=%s", r, debug.Stack())
-				}
-				atomic.AddInt32(&p.size, -1)
-			}()
-			if profiler.IsEnabled(t.ctx) {
-				profiler.Tag(t.ctx)
-				t.f()
-				profiler.Untag(t.ctx)
-			} else {
-				t.f()
-			}
-		}(t)
-		return
+	// make sure the previous ticker goroutine is closed before creating a new one
+	if p.ticker != nil {
+		close(p.ticker)
 	}
+	ch := make(chan struct{})
+	p.ticker = ch
 
-	// background goroutines for consuming tasks
-	go func(t Task) {
-		defer func() {
-			if r := recover(); r != nil {
-				klog.Errorf("panic in wpool: error=%v: stack=%s", r, debug.Stack())
-			}
-			atomic.AddInt32(&p.size, -1)
-		}()
-		// waiting for new task
-		idleTimer := time.NewTimer(p.maxIdleTime)
-		for {
-			if profiler.IsEnabled(t.ctx) {
-				profiler.Tag(t.ctx)
-				t.f()
-				profiler.Untag(t.ctx)
-			} else {
-				t.f()
-			}
-			idleTimer.Reset(p.maxIdleTime)
+	go func(done <-chan struct{}) {
+		// with maxIdleTime=60s and maxIdle=100,
+		// it sends a noop task every 60ms,
+		// but always with d >= 10*time.Millisecond.
+		// this may cause goroutines to take longer to exit, which is acceptable.
+		d := p.maxIdleTime / time.Duration(p.maxIdle) / 10
+		if d < 10*time.Millisecond {
+			d = 10 * time.Millisecond
+		}
+		tk := time.NewTicker(d)
+		for p.Size() > 0 {
 			select {
-			case t = <-p.tasks:
-			case <-idleTimer.C:
-				// worker exits
+			case <-tk.C:
+			case <-done:
 				return
 			}
-			if !idleTimer.Stop() {
-				<-idleTimer.C
+			select {
+			case p.tasks <- nil: // noop task for checking idle time
+			case <-tk.C:
 			}
 		}
-	}(t)
+	}(ch)
+}
+
+func (p *Pool) createWorker(t *task) bool {
+	if n := atomic.AddInt32(&p.size, 1); n < p.maxIdle {
+		if n == 1 {
+			p.createTicker()
+		}
+		go func(t *task) {
+			defer atomic.AddInt32(&p.size, -1)
+
+			t.Run()
+
+			lastactive := time.Now()
+			for t := range p.tasks {
+				if t == nil { // from `createTicker` func
+					if time.Since(lastactive) > p.maxIdleTime {
+						break
+					}
+					continue
+				}
+				t.Run()
+				lastactive = time.Now()
+			}
+		}(t)
+		return true
+	} else {
+		atomic.AddInt32(&p.size, -1)
+		return false
+	}
+}
+
+// RunTask creates or reuses a worker to run the task.
+func (p *Pool) RunTask(ctx context.Context, timeout time.Duration,
+	req, resp any, ep endpoint.Endpoint,
+) (context.Context, error) {
+	t := newTask(ctx, timeout, req, resp, ep)
+	select {
+	case p.tasks <- t:
+		return t.Wait()
+	default:
+	}
+	if !p.createWorker(t) {
+		// if a worker was created, t.Run() will be called in the worker goroutine;
+		// if NOT, we should call `go t.Run()` here.
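+		// This degraded path mirrors the old "single shot" behavior: once size
+		// reaches maxIdle, createWorker refuses to grow the pool and the task
+		// runs in a one-shot goroutine that exits as soon as Run returns.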
+		go t.Run()
+	}
+	return t.Wait()
 }
diff --git a/internal/wpool/pool_test.go b/internal/wpool/pool_test.go
index b933508be8..a982d0893f 100644
--- a/internal/wpool/pool_test.go
+++ b/internal/wpool/pool_test.go
@@ -28,8 +28,9 @@ import (
 )
 
 func TestWPool(t *testing.T) {
-	maxIdle := 1
-	maxIdleTime := time.Millisecond * 500
+	ctx := context.Background()
+	maxIdle := 2
+	maxIdleTime := 100 * time.Millisecond
 	p := New(maxIdle, maxIdleTime)
 	var (
 		sum  int32
@@ -39,32 +40,35 @@ func TestWPool(t *testing.T) {
 	test.Assert(t, p.Size() == 0)
 	for i := int32(0); i < size; i++ {
 		wg.Add(1)
-		p.Go(func() {
-			defer wg.Done()
-			atomic.AddInt32(&sum, 1)
-		})
+		go func() {
+			ctx, err := p.RunTask(ctx, time.Second,
+				nil, nil, func(ctx context.Context, req, resp interface{}) error {
+					defer wg.Done()
+					atomic.AddInt32(&sum, 1)
+					return nil
+				})
+			test.Assert(t, err == nil && ctx.Err() == nil)
+		}()
 	}
-	test.Assert(t, p.Size() != 0)
-
 	wg.Wait()
 	test.Assert(t, atomic.LoadInt32(&sum) == size)
-	for p.Size() != int32(maxIdle) { // waiting for workers finished and idle workers left
-		runtime.Gosched()
-	}
-	for p.Size() > 0 { // waiting for idle workers timeout
-		time.Sleep(maxIdleTime)
-	}
-	test.Assert(t, p.Size() == 0)
+	test.Assert(t, p.Size() != 0)
+	time.Sleep(2 * maxIdleTime)
+	test.Assert(t, p.Size() == 0, p.Size())
 }
 
+func noop(ctx context.Context, req, resp interface{}) error { return nil }
+
 func BenchmarkWPool(b *testing.B) {
 	maxIdleWorkers := runtime.GOMAXPROCS(0)
 	ctx := context.Background()
 	p := New(maxIdleWorkers, 10*time.Millisecond)
-	for i := 0; i < b.N; i++ {
-		p.GoCtx(ctx, func() {})
-		for int(p.Size()) > maxIdleWorkers {
-			runtime.Gosched()
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			p.RunTask(ctx, time.Second, nil, nil, noop)
+			for int(p.Size()) > maxIdleWorkers {
+				runtime.Gosched()
+			}
 		}
-	}
+	})
 }
diff --git a/internal/wpool/task.go b/internal/wpool/task.go
new file mode 100644
index 0000000000..a9b8214e2f
--- /dev/null
+++ b/internal/wpool/task.go
@@ -0,0 +1,161 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package wpool
+
+import (
+	"context"
+	"fmt"
+	"runtime/debug"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/cloudwego/kitex/pkg/endpoint"
+	"github.com/cloudwego/kitex/pkg/profiler"
+	"github.com/cloudwego/kitex/pkg/rpcinfo"
+)
+
+var poolTask = sync.Pool{
+	New: func() any {
+		return &task{}
+	},
+}
+
+// task wraps one endpoint call together with its timeout-aware context.
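+// A task is pooled via poolTask: recycle() blocks on wg until the matching
+// Wait call has finished, so a task is never reused while a waiter still
+// holds it. Its taskContext, however, is never pooled (see newTask).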
+type task struct {
+	ctx *taskContext
+
+	wg sync.WaitGroup
+
+	req, resp any
+	ep        endpoint.Endpoint
+
+	err atomic.Value
+}
+
+func newTask(ctx context.Context, timeout time.Duration,
+	req, resp any, ep endpoint.Endpoint,
+) *task {
+	t := poolTask.Get().(*task)
+
+	// taskContext must not be reused,
+	// because the user may keep a reference to it even after endpoint.Endpoint returns
+	t.ctx = newTaskContext(ctx, timeout)
+
+	t.req, t.resp = req, resp
+	t.ep = ep
+	t.err = atomic.Value{}
+	t.wg.Add(1) // for Wait; Wait must be called before recycle()
+	return t
+}
+
+func (t *task) recycle() {
+	// make sure Wait is done before returning the task to the pool
+	t.wg.Wait()
+
+	t.ctx = nil
+	t.req, t.resp = nil, nil
+	t.ep = nil
+	t.err = atomic.Value{}
+	poolTask.Put(t)
+}
+
+func (t *task) Cancel(err error) {
+	t.ctx.Cancel(err)
+}
+
+// Run must be called in a separate goroutine
+func (t *task) Run() {
+	defer func() {
+		if panicInfo := recover(); panicInfo != nil {
+			ri := rpcinfo.GetRPCInfo(t.ctx)
+			if ri != nil {
+				t.err.Store(rpcinfo.ClientPanicToErr(t.ctx, panicInfo, ri, true))
+			} else {
+				t.err.Store(fmt.Errorf("KITEX: panic without rpcinfo, error=%v\nstack=%s",
+					panicInfo, debug.Stack()))
+			}
+		}
+		t.Cancel(errTaskDone)
+		t.recycle()
+	}()
+	var err error
+	if profiler.IsEnabled(t.ctx) {
+		profiler.Tag(t.ctx)
+		err = t.ep(t.ctx, t.req, t.resp)
+		profiler.Untag(t.ctx)
+	} else {
+		err = t.ep(t.ctx, t.req, t.resp)
+	}
+	if err != nil {
+		t.err.Store(err) // atomic.Value panics on a nil store, so only store non-nil errors
+	}
+}
+
+var poolTimer = sync.Pool{
+	New: func() any {
+		return time.NewTimer(time.Second)
+	},
+}
+
+// Wait blocks until Run finishes and returns the result.
+func (t *task) Wait() (context.Context, error) {
+	defer t.wg.Done()
+	dl, ok := t.ctx.Deadline()
+	if !ok {
+		return t.waitNoTimeout()
+	}
+	d := time.Until(dl)
+	if d < 0 {
+		t.Cancel(context.DeadlineExceeded)
+		return t.ctx, t.ctx.Err()
+	}
+	tm := poolTimer.Get().(*time.Timer)
+	if !tm.Stop() {
+		select { // the pooled timer may have expired or been stopped
+		case <-tm.C:
+		default:
+		}
+	}
+	defer poolTimer.Put(tm)
+	tm.Reset(d)
+	select {
+	case <-t.ctx.Done():
+		// Run returned before the timeout
+	case <-t.ctx.Context.Done():
+		t.Cancel(t.ctx.Context.Err())
+	case <-tm.C:
+		t.Cancel(context.DeadlineExceeded)
+	}
+	if v := t.err.Load(); v != nil {
+		return t.ctx, v.(error)
+	}
+	return t.ctx, t.ctx.Err()
+}
+
+func (t *task) waitNoTimeout() (context.Context, error) {
+	select {
+	case <-t.ctx.Done():
+		// Run returned before the parent context was canceled
+	case <-t.ctx.Context.Done():
+		t.Cancel(t.ctx.Context.Err())
+	}
+	if v := t.err.Load(); v != nil {
+		return t.ctx, v.(error)
+	}
+	return t.ctx, t.ctx.Err()
+}
diff --git a/internal/wpool/task_test.go b/internal/wpool/task_test.go
new file mode 100644
index 0000000000..c3ca2cbb57
--- /dev/null
+++ b/internal/wpool/task_test.go
@@ -0,0 +1,101 @@
+/*
+ * Copyright 2024 CloudWeGo Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package wpool + +import ( + "context" + "errors" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/cloudwego/kitex/internal/test" +) + +func TestTask(t *testing.T) { + t.Run("ReturnErr", func(t *testing.T) { + returnErr := errors.New("ep error") + ctx := context.Background() + timeout := 20 * time.Millisecond + p := newTask(ctx, timeout, nil, nil, + func(ctx context.Context, _, _ any) error { + time.Sleep(timeout / 2) + return returnErr + }) + + go p.Run() + ctx, err := p.Wait() + test.Assert(t, returnErr == err, err) + test.Assert(t, ctx.Err() == nil, ctx.Err()) + }) + + t.Run("Timeout", func(t *testing.T) { + ctx := context.Background() + timeout := 50 * time.Millisecond + dl := time.Now().Add(timeout) + var returned atomic.Value + p := newTask(ctx, timeout, nil, nil, func(ctx context.Context, _, _ any) error { + d, ok := ctx.Deadline() + test.Assert(t, ok) + test.Assert(t, d.Sub(dl) < timeout/4) + <-ctx.Done() + returned.Store(true) + time.Sleep(timeout / 4) + return errors.New("ep error") + }) + + t0 := time.Now() + go p.Run() + ctx, err := p.Wait() + t1 := time.Now() + test.Assert(t, errors.Is(err, context.DeadlineExceeded), err) + test.Assert(t, errors.Is(ctx.Err(), context.DeadlineExceeded), ctx.Err()) + test.Assert(t, t1.Sub(t0)-timeout < timeout/4) + time.Sleep(timeout / 2) + test.Assert(t, returned.Load() != nil) + }) + + t.Run("ParentCtxDone", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + p := newTask(ctx, 0, nil, nil, + func(ctx context.Context, _, _ any) error { + cancel() + time.Sleep(10 * time.Millisecond) + return nil + }) + + go p.Run() + ctx, err := p.Wait() + test.Assert(t, err == context.Canceled, err) + test.Assert(t, errors.Is(ctx.Err(), context.Canceled), ctx.Err()) + }) + + t.Run("Panic", func(t *testing.T) { + ctx := context.Background() + timeout := 20 * time.Millisecond + p := newTask(ctx, timeout, nil, nil, + func(ctx context.Context, _, _ any) error { + panic("testpanic") + }) + + go p.Run() + ctx, err := p.Wait() + test.Assert(t, err != nil && strings.Contains(err.Error(), "testpanic"), err) + test.Assert(t, ctx.Err() == nil) + }) +}
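+
+// TestRunTaskOverflow is an illustrative usage sketch (the test name, pool
+// sizing, and timings are editorial assumptions, not part of the original
+// change). With maxIdle=2 the pool keeps at most one worker (createWorker
+// requires n < maxIdle), so the remaining concurrent tasks take the
+// fallback `go t.Run()` path in RunTask, and all of them still complete.
+func TestRunTaskOverflow(t *testing.T) {
+	p := New(2, 50*time.Millisecond)
+	block := make(chan struct{})
+	done := make(chan error, 4)
+	for i := 0; i < 4; i++ {
+		go func() {
+			_, err := p.RunTask(context.Background(), time.Second, nil, nil,
+				func(ctx context.Context, _, _ any) error {
+					<-block // hold each worker/goroutine until all tasks have started
+					return nil
+				})
+			done <- err
+		}()
+	}
+	time.Sleep(10 * time.Millisecond) // let every task start and block
+	close(block)
+	for i := 0; i < 4; i++ {
+		err := <-done
+		test.Assert(t, err == nil, err)
+	}
+}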