diff --git a/client/rpctimeout.go b/client/rpctimeout.go index 4a8cb588ba..c60ada4f6a 100644 --- a/client/rpctimeout.go +++ b/client/rpctimeout.go @@ -96,68 +96,47 @@ func rpcTimeoutMW(mwCtx context.Context) endpoint.Middleware { } return func(next endpoint.Endpoint) endpoint.Endpoint { + backgroundEP := func(ctx context.Context, request, response interface{}) error { + err := next(ctx, request, response) + if err != nil && ctx.Err() != nil && + !kerrors.IsTimeoutError(err) && !errors.Is(err, kerrors.ErrRPCFinish) { + ri := rpcinfo.GetRPCInfo(ctx) + // error occurs after the wait goroutine returns(RPCTimeout happens), + // we should log this error for troubleshooting, or it will be discarded. + // but ErrRPCTimeout and ErrRPCFinish can be ignored: + // ErrRPCTimeout: it is same with outer timeout, here only care about non-timeout err. + // ErrRPCFinish: it happens in retry scene, previous call returns first. + var errMsg string + if ri.To().Address() != nil { + errMsg = fmt.Sprintf("KITEX: to_service=%s method=%s addr=%s error=%s", + ri.To().ServiceName(), ri.To().Method(), ri.To().Address(), err.Error()) + } else { + errMsg = fmt.Sprintf("KITEX: to_service=%s method=%s error=%s", + ri.To().ServiceName(), ri.To().Method(), err.Error()) + } + klog.CtxErrorf(ctx, "%s", errMsg) + } + return err + } return func(ctx context.Context, request, response interface{}) error { ri := rpcinfo.GetRPCInfo(ctx) if ri.Config().InteractionMode() == rpcinfo.Streaming { return next(ctx, request, response) } - tm := ri.Config().RPCTimeout() if tm > 0 { tm += moreTimeout - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, tm) - defer cancel() } - // Fast path for ctx without timeout - if ctx.Done() == nil { + // fast path if no timeout + if ctx.Done() == nil && tm <= 0 { return next(ctx, request, response) } - - var err error start := time.Now() - done := make(chan error, 1) - workerPool.GoCtx(ctx, func() { - defer func() { - if panicInfo := recover(); panicInfo != nil { - e := rpcinfo.ClientPanicToErr(ctx, panicInfo, ri, true) - done <- e - } - if err == nil || !errors.Is(err, kerrors.ErrRPCFinish) { - // Don't regards ErrRPCFinish as normal error, it happens in retry scene, - // ErrRPCFinish means previous call returns first but is decoding. - close(done) - } - }() - err = next(ctx, request, response) - if err != nil && ctx.Err() != nil && - !kerrors.IsTimeoutError(err) && !errors.Is(err, kerrors.ErrRPCFinish) { - // error occurs after the wait goroutine returns(RPCTimeout happens), - // we should log this error for troubleshooting, or it will be discarded. - // but ErrRPCTimeout and ErrRPCFinish can be ignored: - // ErrRPCTimeout: it is same with outer timeout, here only care about non-timeout err. - // ErrRPCFinish: it happens in retry scene, previous call returns first. - var errMsg string - if ri.To().Address() != nil { - errMsg = fmt.Sprintf("KITEX: to_service=%s method=%s addr=%s error=%s", - ri.To().ServiceName(), ri.To().Method(), ri.To().Address(), err.Error()) - } else { - errMsg = fmt.Sprintf("KITEX: to_service=%s method=%s error=%s", - ri.To().ServiceName(), ri.To().Method(), err.Error()) - } - klog.CtxErrorf(ctx, "%s", errMsg) - } - }) - - select { - case panicErr := <-done: - if panicErr != nil { - return panicErr - } - return err - case <-ctx.Done(): + ctx, err := workerPool.RunTask(ctx, tm, request, response, backgroundEP) + if ctx.Err() != nil { // context.Canceled or context.DeadlineExceeded? return makeTimeoutErr(ctx, start, tm) } + return err } } } diff --git a/client/rpctimeout_test.go b/client/rpctimeout_test.go index f0aacce011..c1e8ff54ec 100644 --- a/client/rpctimeout_test.go +++ b/client/rpctimeout_test.go @@ -313,3 +313,17 @@ func Test_isBusinessTimeout(t *testing.T) { }) } } + +func BenchmarkRPCTimeoutMW(b *testing.B) { + s := rpcinfo.NewEndpointInfo("mockService", "mockMethod", nil, nil) + c := rpcinfo.NewRPCConfig() + r := rpcinfo.NewRPCInfo(nil, s, nil, c, rpcinfo.NewRPCStats()) + m := rpcinfo.AsMutableRPCConfig(c) + m.SetRPCTimeout(20 * time.Millisecond) + ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), r) + mw := rpcTimeoutMW(ctx) + ep := mw(func(ctx context.Context, req, resp interface{}) (err error) { return nil }) + for i := 0; i < b.N; i++ { + ep(ctx, b, b) + } +} diff --git a/internal/wpool/context.go b/internal/wpool/context.go new file mode 100644 index 0000000000..94e16db5c8 --- /dev/null +++ b/internal/wpool/context.go @@ -0,0 +1,81 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package wpool + +import ( + "context" + "errors" + "sync" + "time" +) + +type taskContext struct { + context.Context + + dl time.Time + ch chan struct{} + + mu sync.Mutex + err error +} + +func newTaskContext(ctx context.Context, timeout time.Duration) *taskContext { + ret := &taskContext{Context: ctx, ch: make(chan struct{})} + deadline, ok := ctx.Deadline() + if ok { + ret.dl = deadline + } + if timeout != 0 { + dl := time.Now().Add(timeout) + if ret.dl.IsZero() || dl.Before(ret.dl) { + // The new deadline is sooner than the ctx one + ret.dl = dl + } + } + return ret +} + +func (p *taskContext) Deadline() (deadline time.Time, ok bool) { + return p.dl, !p.dl.IsZero() +} + +func (p *taskContext) Done() <-chan struct{} { + return p.ch +} + +// only used internally +// will not return to end user +var errTaskDone = errors.New("task done") + +func (p *taskContext) Err() error { + p.mu.Lock() + err := p.err + p.mu.Unlock() + if err == nil || err == errTaskDone { + return p.Context.Err() + } + return err +} + +func (p *taskContext) Cancel(err error) { + p.mu.Lock() + if err != nil && p.err == nil { + p.err = err + close(p.ch) + } + p.mu.Unlock() +} diff --git a/internal/wpool/pool.go b/internal/wpool/pool.go index da04ec0bda..aac3699767 100644 --- a/internal/wpool/pool.go +++ b/internal/wpool/pool.go @@ -23,36 +23,30 @@ package wpool import ( "context" - "runtime/debug" "sync/atomic" "time" - "github.com/cloudwego/kitex/pkg/klog" - "github.com/cloudwego/kitex/pkg/profiler" + "github.com/cloudwego/kitex/pkg/endpoint" ) -// Task is the function that the worker will execute. -type Task struct { - ctx context.Context - f func() -} - -// Pool is a worker pool bind with some idle goroutines. +// Pool is a worker pool for task with timeout type Pool struct { size int32 - tasks chan Task + tasks chan *Task // maxIdle is the number of the max idle workers in the pool. // if maxIdle too small, the pool works like a native 'go func()'. maxIdle int32 // maxIdleTime is the max idle time that the worker will wait for the new task. maxIdleTime time.Duration + + ticker chan struct{} } // New creates a new worker pool. func New(maxIdle int, maxIdleTime time.Duration) *Pool { return &Pool{ - tasks: make(chan Task), + tasks: make(chan *Task), maxIdle: int32(maxIdle), maxIdleTime: maxIdleTime, } @@ -63,69 +57,83 @@ func (p *Pool) Size() int32 { return atomic.LoadInt32(&p.size) } -// Go creates/reuses a worker to run task. -func (p *Pool) Go(f func()) { - p.GoCtx(context.Background(), f) +func (p *Pool) createTicker() { + // make sure previous goroutine will be closed before creating a new one + if p.ticker != nil { + close(p.ticker) + } + ch := make(chan struct{}) + p.ticker = ch + + go func() { + // if maxIdleTime=60s, maxIdle=100 + // it sends noop task every 60ms + // but always d >= 10*time.Millisecond + // this may casue goroutines take more time to exit which is acceptable. + d := p.maxIdleTime / time.Duration(p.maxIdle) / 10 + if d < 10*time.Millisecond { + d = 10 * time.Millisecond + } + tk := time.NewTicker(d) + for p.Size() > 0 { + select { + case <-tk.C: + case <-ch: + return + } + select { + case p.tasks <- nil: // noop task for checking idletime + case <-tk.C: + } + } + }() } -// GoCtx creates/reuses a worker to run task. -func (p *Pool) GoCtx(ctx context.Context, f func()) { - t := Task{ctx: ctx, f: f} - select { - case p.tasks <- t: - // reuse exist worker - return - default: - } +func (p *Pool) createWorker(t *Task) bool { + if n := atomic.AddInt32(&p.size, 1); n < p.maxIdle { + if n == 1 { + p.createTicker() + } + go func(t *Task) { + defer atomic.AddInt32(&p.size, -1) + + t.Run() + t.Recycle() - // single shot if p.size > p.maxIdle - if atomic.AddInt32(&p.size, 1) > p.maxIdle { - go func(t Task) { - defer func() { - if r := recover(); r != nil { - klog.Errorf("panic in wpool: error=%v: stack=%s", r, debug.Stack()) + lastactive := time.Now() + for t := range p.tasks { + if t == nil { // from `createTicker` func + if time.Since(lastactive) > p.maxIdleTime { + break + } + continue } - atomic.AddInt32(&p.size, -1) - }() - if profiler.IsEnabled(t.ctx) { - profiler.Tag(t.ctx) - t.f() - profiler.Untag(t.ctx) - } else { - t.f() + t.Run() + t.Recycle() + lastactive = time.Now() } }(t) - return + return true + } else { + atomic.AddInt32(&p.size, -1) + return false } +} - // background goroutines for consuming tasks - go func(t Task) { - defer func() { - if r := recover(); r != nil { - klog.Errorf("panic in wpool: error=%v: stack=%s", r, debug.Stack()) - } - atomic.AddInt32(&p.size, -1) - }() - // waiting for new task - idleTimer := time.NewTimer(p.maxIdleTime) - for { - if profiler.IsEnabled(t.ctx) { - profiler.Tag(t.ctx) - t.f() - profiler.Untag(t.ctx) - } else { - t.f() - } - idleTimer.Reset(p.maxIdleTime) - select { - case t = <-p.tasks: - case <-idleTimer.C: - // worker exits - return - } - if !idleTimer.Stop() { - <-idleTimer.C - } - } - }(t) +// RunTask creates/reuses a worker to run task. +func (p *Pool) RunTask(ctx context.Context, timeout time.Duration, + req, resp any, ep endpoint.Endpoint) (context.Context, error) { + t := NewTask(ctx, timeout, req, resp, ep) + select { + case p.tasks <- t: + return t.Wait() + default: + } + if !p.createWorker(t) { + go func(t *Task) { + t.Run() + t.Recycle() + }(t) + } + return t.Wait() } diff --git a/internal/wpool/pool_test.go b/internal/wpool/pool_test.go index b933508be8..22ef3321c7 100644 --- a/internal/wpool/pool_test.go +++ b/internal/wpool/pool_test.go @@ -28,8 +28,9 @@ import ( ) func TestWPool(t *testing.T) { - maxIdle := 1 - maxIdleTime := time.Millisecond * 500 + ctx := context.Background() + maxIdle := 2 + maxIdleTime := 100 * time.Millisecond p := New(maxIdle, maxIdleTime) var ( sum int32 @@ -39,22 +40,21 @@ func TestWPool(t *testing.T) { test.Assert(t, p.Size() == 0) for i := int32(0); i < size; i++ { wg.Add(1) - p.Go(func() { - defer wg.Done() - atomic.AddInt32(&sum, 1) - }) + go func() { + ctx, err := p.RunTask(ctx, time.Second, + nil, nil, func(ctx context.Context, req, resp interface{}) error { + defer wg.Done() + atomic.AddInt32(&sum, 1) + return nil + }) + test.Assert(t, err == nil && ctx.Err() == nil) + }() } - test.Assert(t, p.Size() != 0) - wg.Wait() test.Assert(t, atomic.LoadInt32(&sum) == size) - for p.Size() != int32(maxIdle) { // waiting for workers finished and idle workers left - runtime.Gosched() - } - for p.Size() > 0 { // waiting for idle workers timeout - time.Sleep(maxIdleTime) - } - test.Assert(t, p.Size() == 0) + test.Assert(t, p.Size() != 0) + time.Sleep(2 * maxIdleTime) + test.Assert(t, p.Size() == 0, p.Size()) } func BenchmarkWPool(b *testing.B) { @@ -62,7 +62,8 @@ func BenchmarkWPool(b *testing.B) { ctx := context.Background() p := New(maxIdleWorkers, 10*time.Millisecond) for i := 0; i < b.N; i++ { - p.GoCtx(ctx, func() {}) + p.RunTask(ctx, time.Second, nil, nil, + func(ctx context.Context, req, resp interface{}) error { return nil }) for int(p.Size()) > maxIdleWorkers { runtime.Gosched() } diff --git a/internal/wpool/task.go b/internal/wpool/task.go new file mode 100644 index 0000000000..93a82604f1 --- /dev/null +++ b/internal/wpool/task.go @@ -0,0 +1,160 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package wpool + +import ( + "context" + "fmt" + "runtime/debug" + "sync" + "sync/atomic" + "time" + + "github.com/cloudwego/kitex/pkg/endpoint" + "github.com/cloudwego/kitex/pkg/profiler" + "github.com/cloudwego/kitex/pkg/rpcinfo" +) + +var poolTask = sync.Pool{ + New: func() any { + return &Task{} + }, +} + +// Task is the function that the worker will execute. +type Task struct { + ctx *taskContext + + wg sync.WaitGroup + + req, resp any + ep endpoint.Endpoint + + err atomic.Value +} + +func NewTask(ctx context.Context, timeout time.Duration, + req, resp any, ep endpoint.Endpoint) *Task { + t := poolTask.Get().(*Task) + + // taskContext must not be reused, + // coz user may keep ref to it even though after calling endpoint.Endpoint + t.ctx = newTaskContext(ctx, timeout) + + t.req, t.resp = req, resp + t.ep = ep + t.err = atomic.Value{} + t.wg.Add(2) // 1 for Run, 2 for WaitTimeout + return t +} + +func (t *Task) Recycle() { + // make sure Run & WaitTimeout done before returning it to pool + t.wg.Wait() + + t.ctx = nil + t.req, t.resp = nil, nil + t.ep = nil + t.err = atomic.Value{} + poolTask.Put(t) +} + +func (t *Task) Cancel(err error) { + t.ctx.Cancel(err) +} + +func (t *Task) Run() { + defer func() { + if panicInfo := recover(); panicInfo != nil { + ri := rpcinfo.GetRPCInfo(t.ctx) + if ri != nil { + t.err.Store(rpcinfo.ClientPanicToErr(t.ctx, panicInfo, ri, true)) + } else { + t.err.Store(fmt.Errorf("KITEX: panic without rpcinfo, error=%v\nstack=%s", + panicInfo, debug.Stack())) + } + } + t.Cancel(errTaskDone) + t.wg.Done() + }() + var err error + if profiler.IsEnabled(t.ctx) { + profiler.Tag(t.ctx) + err = t.ep(t.ctx, t.req, t.resp) + profiler.Untag(t.ctx) + } else { + err = t.ep(t.ctx, t.req, t.resp) + } + if err != nil { + t.err.Store(err) // fix store nil value ... + } +} + +var poolTimer = sync.Pool{ + New: func() any { + return time.NewTimer(time.Second) + }, +} + +func (t *Task) Wait() (context.Context, error) { + defer t.wg.Done() + + dl, ok := t.ctx.Deadline() + if !ok { + return t.waitNoTimeout() + } + d := dl.Sub(time.Now()) + if d < 0 { + t.Cancel(context.DeadlineExceeded) + return t.ctx, t.ctx.Err() + } + + tm := poolTimer.Get().(*time.Timer) + if !tm.Stop() { + select { // it may be expired or stopped + case <-tm.C: + default: + } + } + defer poolTimer.Put(tm) + tm.Reset(d) + select { + case <-t.ctx.Done(): + // Run returned before timeout + case <-t.ctx.Context.Done(): + t.Cancel(t.ctx.Context.Err()) + case <-tm.C: + t.Cancel(context.DeadlineExceeded) + } + if v := t.err.Load(); v != nil { + return t.ctx, v.(error) + } + return t.ctx, t.ctx.Err() +} + +func (t *Task) waitNoTimeout() (context.Context, error) { + select { + case <-t.ctx.Done(): + // Run returned before timeout + case <-t.ctx.Context.Done(): + t.Cancel(t.ctx.Context.Err()) + } + if v := t.err.Load(); v != nil { + return t.ctx, v.(error) + } + return t.ctx, t.ctx.Err() +} diff --git a/internal/wpool/task_test.go b/internal/wpool/task_test.go new file mode 100644 index 0000000000..dfbcbda9b3 --- /dev/null +++ b/internal/wpool/task_test.go @@ -0,0 +1,104 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package wpool + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + "github.com/cloudwego/kitex/internal/test" +) + +func TestTask(t *testing.T) { + t.Run("ReturnErr", func(t *testing.T) { + returnErr := errors.New("ep error") + ctx := context.Background() + timeout := 20 * time.Millisecond + p := NewTask(ctx, timeout, nil, nil, + func(ctx context.Context, _, _ any) error { + time.Sleep(timeout / 2) + return returnErr + }) + + go p.Run() + ctx, err := p.Wait() + p.Recycle() + test.Assert(t, returnErr == err, err) + test.Assert(t, ctx.Err() == nil, ctx.Err()) + }) + + t.Run("Timeout", func(t *testing.T) { + ctx := context.Background() + timeout := 20 * time.Millisecond + dl := time.Now().Add(timeout) + returned := false + p := NewTask(ctx, timeout, nil, nil, func(ctx context.Context, _, _ any) error { + d, ok := ctx.Deadline() + test.Assert(t, ok) + test.Assert(t, d.Sub(dl) < timeout/2) + <-ctx.Done() + returned = true + return errors.New("ep error") + }) + + go p.Run() + t0 := time.Now() + ctx, err := p.Wait() + p.Recycle() + test.Assert(t, errors.Is(err, context.DeadlineExceeded), err) + test.Assert(t, errors.Is(ctx.Err(), context.DeadlineExceeded), ctx.Err()) + test.Assert(t, time.Since(t0)-timeout < timeout/4) + time.Sleep(timeout / 4) + test.Assert(t, returned == true) + }) + + t.Run("ParentCtxDone", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + p := NewTask(ctx, 0, nil, nil, + func(ctx context.Context, _, _ any) error { + cancel() + time.Sleep(10 * time.Millisecond) + return nil + }) + + go p.Run() + ctx, err := p.Wait() + p.Recycle() + test.Assert(t, err == context.Canceled, err) + test.Assert(t, errors.Is(ctx.Err(), context.Canceled), ctx.Err()) + }) + + t.Run("Panic", func(t *testing.T) { + ctx := context.Background() + timeout := 20 * time.Millisecond + p := NewTask(ctx, timeout, nil, nil, + func(ctx context.Context, _, _ any) error { + panic("testpanic") + return nil + }) + + go p.Run() + ctx, err := p.Wait() + p.Recycle() + test.Assert(t, err != nil && strings.Contains(err.Error(), "testpanic"), err) + test.Assert(t, ctx.Err() == nil) + }) + +}