perf(client): new rpctimeout impl #1581

Open · wants to merge 1 commit into base: develop
75 changes: 27 additions & 48 deletions client/rpctimeout.go
@@ -96,68 +96,47 @@ func rpcTimeoutMW(mwCtx context.Context) endpoint.Middleware {
}

return func(next endpoint.Endpoint) endpoint.Endpoint {
backgroundEP := func(ctx context.Context, request, response interface{}) error {
err := next(ctx, request, response)
if err != nil && ctx.Err() != nil &&
!kerrors.IsTimeoutError(err) && !errors.Is(err, kerrors.ErrRPCFinish) {
ri := rpcinfo.GetRPCInfo(ctx)
// The error occurs after the wait goroutine has returned (an RPC timeout happened);
// log it for troubleshooting, otherwise it would be discarded.
// ErrRPCTimeout and ErrRPCFinish can be ignored:
// ErrRPCTimeout: same as the outer timeout; only non-timeout errors matter here.
// ErrRPCFinish: happens in the retry scenario when a previous call returned first.
var errMsg string
if ri.To().Address() != nil {
errMsg = fmt.Sprintf("KITEX: to_service=%s method=%s addr=%s error=%s",
ri.To().ServiceName(), ri.To().Method(), ri.To().Address(), err.Error())
} else {
errMsg = fmt.Sprintf("KITEX: to_service=%s method=%s error=%s",
ri.To().ServiceName(), ri.To().Method(), err.Error())
}
klog.CtxErrorf(ctx, "%s", errMsg)
}
return err
}
return func(ctx context.Context, request, response interface{}) error {
ri := rpcinfo.GetRPCInfo(ctx)
if ri.Config().InteractionMode() == rpcinfo.Streaming {
return next(ctx, request, response)
}

tm := ri.Config().RPCTimeout()
if tm > 0 {
tm += moreTimeout
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, tm)
defer cancel()
}
// Fast path for ctx without timeout
if ctx.Done() == nil {
// fast path if no timeout
if ctx.Done() == nil && tm <= 0 {
return next(ctx, request, response)
}

var err error
start := time.Now()
done := make(chan error, 1)
workerPool.GoCtx(ctx, func() {
defer func() {
if panicInfo := recover(); panicInfo != nil {
e := rpcinfo.ClientPanicToErr(ctx, panicInfo, ri, true)
done <- e
}
if err == nil || !errors.Is(err, kerrors.ErrRPCFinish) {
// Don't regard ErrRPCFinish as a normal error; it happens in the retry scenario:
// ErrRPCFinish means a previous call returned first and is still decoding.
close(done)
}
}()
err = next(ctx, request, response)
if err != nil && ctx.Err() != nil &&
!kerrors.IsTimeoutError(err) && !errors.Is(err, kerrors.ErrRPCFinish) {
// The error occurs after the wait goroutine has returned (an RPC timeout happened);
// log it for troubleshooting, otherwise it would be discarded.
// ErrRPCTimeout and ErrRPCFinish can be ignored:
// ErrRPCTimeout: same as the outer timeout; only non-timeout errors matter here.
// ErrRPCFinish: happens in the retry scenario when a previous call returned first.
var errMsg string
if ri.To().Address() != nil {
errMsg = fmt.Sprintf("KITEX: to_service=%s method=%s addr=%s error=%s",
ri.To().ServiceName(), ri.To().Method(), ri.To().Address(), err.Error())
} else {
errMsg = fmt.Sprintf("KITEX: to_service=%s method=%s error=%s",
ri.To().ServiceName(), ri.To().Method(), err.Error())
}
klog.CtxErrorf(ctx, "%s", errMsg)
}
})

select {
case panicErr := <-done:
if panicErr != nil {
return panicErr
}
return err
case <-ctx.Done():
ctx, err := workerPool.RunTask(ctx, tm, request, response, backgroundEP)
if ctx.Err() != nil { // context.Canceled or context.DeadlineExceeded?
return makeTimeoutErr(ctx, start, tm)
}
return err
}
}
}
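
Conceptually, the middleware now delegates the whole timeout flow to workerPool.RunTask instead of allocating a context.WithTimeout, a done channel, and a fresh goroutine closure on every call. The observable behavior can be sketched with plain goroutines as below; this is an illustration only, since the real pool avoids exactly these per-call allocations (runWithTimeout and the toy endpoint are hypothetical names, not part of the PR):

package main

import (
    "context"
    "errors"
    "fmt"
    "time"
)

// runWithTimeout mirrors what the middleware observes from RunTask: the call
// runs elsewhere, and context expiry is translated into a timeout error
// (makeTimeoutErr in the real code) while the worker finishes in the background.
func runWithTimeout(ctx context.Context, tm time.Duration, call func(context.Context) error) error {
    ctx, cancel := context.WithTimeout(ctx, tm)
    defer cancel()
    done := make(chan error, 1)
    go func() { done <- call(ctx) }() // the pool reuses goroutines instead
    select {
    case err := <-done:
        return err
    case <-ctx.Done():
        return fmt.Errorf("rpc timeout after %s: %w", tm, ctx.Err())
    }
}

func main() {
    err := runWithTimeout(context.Background(), 10*time.Millisecond,
        func(ctx context.Context) error { time.Sleep(time.Second); return nil })
    fmt.Println(errors.Is(err, context.DeadlineExceeded)) // true
}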
14 changes: 14 additions & 0 deletions client/rpctimeout_test.go
@@ -313,3 +313,17 @@ func Test_isBusinessTimeout(t *testing.T) {
})
}
}

func BenchmarkRPCTimeoutMW(b *testing.B) {
s := rpcinfo.NewEndpointInfo("mockService", "mockMethod", nil, nil)
c := rpcinfo.NewRPCConfig()
r := rpcinfo.NewRPCInfo(nil, s, nil, c, rpcinfo.NewRPCStats())
m := rpcinfo.AsMutableRPCConfig(c)
m.SetRPCTimeout(20 * time.Millisecond)
ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), r)
mw := rpcTimeoutMW(ctx)
ep := mw(func(ctx context.Context, req, resp interface{}) (err error) { return nil })
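// b is passed as a placeholder request/response; the endpoint above ignores both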
for i := 0; i < b.N; i++ {
ep(ctx, b, b)
}
}
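
The benchmark can be run with the standard Go tooling, e.g. go test -bench=BenchmarkRPCTimeoutMW -benchmem ./client; the -benchmem flag reports per-op allocations, which is what this change aims to reduce.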
81 changes: 81 additions & 0 deletions internal/wpool/context.go
@@ -0,0 +1,81 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package wpool

import (
"context"
"errors"
"sync"
"time"
)

type taskContext struct {
context.Context

dl time.Time
ch chan struct{}

mu sync.Mutex
err error
}

func newTaskContext(ctx context.Context, timeout time.Duration) *taskContext {
ret := &taskContext{Context: ctx, ch: make(chan struct{})}
deadline, ok := ctx.Deadline()
if ok {
ret.dl = deadline
}
if timeout != 0 {
dl := time.Now().Add(timeout)
if ret.dl.IsZero() || dl.Before(ret.dl) {
// The new deadline is sooner than the ctx one
ret.dl = dl
}
}
return ret
}

func (p *taskContext) Deadline() (deadline time.Time, ok bool) {
return p.dl, !p.dl.IsZero()
}

func (p *taskContext) Done() <-chan struct{} {
return p.ch
}

// errTaskDone is only used internally and is never returned to the end user.
var errTaskDone = errors.New("task done")

func (p *taskContext) Err() error {
p.mu.Lock()
err := p.err
p.mu.Unlock()
if err == nil || err == errTaskDone {
return p.Context.Err()
}
return err
}

func (p *taskContext) Cancel(err error) {
p.mu.Lock()
if err != nil && p.err == nil {
p.err = err
close(p.ch)
}
p.mu.Unlock()
}
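
To make the deadline-merging and cancellation rules above concrete, here is a small test-style sketch that would live in package wpool next to this file. It uses only names defined above; the expectations are one reading of the code, not something asserted by the PR:

package wpool

import (
    "context"
    "testing"
    "time"
)

func TestTaskContextSketch(t *testing.T) {
    parent, cancel := context.WithTimeout(context.Background(), time.Second)
    defer cancel()

    // The 100ms task timeout is sooner than the parent's 1s deadline, so it wins.
    tc := newTaskContext(parent, 100*time.Millisecond)
    if dl, ok := tc.Deadline(); !ok || time.Until(dl) > 150*time.Millisecond {
        t.Fatalf("expected ~100ms deadline, got %v (ok=%v)", dl, ok)
    }

    // Done() is driven by Cancel, not by the parent context.
    select {
    case <-tc.Done():
        t.Fatal("should not be done yet")
    default:
    }
    tc.Cancel(context.DeadlineExceeded)
    <-tc.Done() // closed by Cancel
    if tc.Err() != context.DeadlineExceeded {
        t.Fatalf("unexpected err: %v", tc.Err())
    }
}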
149 changes: 80 additions & 69 deletions internal/wpool/pool.go
@@ -23,36 +23,32 @@ package wpool

import (
"context"
"runtime/debug"
"sync"
"sync/atomic"
"time"

"github.com/cloudwego/kitex/pkg/klog"
"github.com/cloudwego/kitex/pkg/profiler"
"github.com/cloudwego/kitex/pkg/endpoint"
)

// Task is the function that the worker will execute.
type Task struct {
ctx context.Context
f func()
}

// Pool is a worker pool bound to some idle goroutines.
// Pool is a worker pool for tasks with timeouts.
type Pool struct {
size int32
tasks chan Task
tasks chan *task

// maxIdle is the maximum number of idle workers in the pool.
// if maxIdle is too small, the pool works like a native 'go func()'.
maxIdle int32
// maxIdleTime is the maximum time a worker waits for a new task before exiting.
maxIdleTime time.Duration

mu sync.Mutex
ticker chan struct{}
}

// New creates a new worker pool.
func New(maxIdle int, maxIdleTime time.Duration) *Pool {
return &Pool{
tasks: make(chan Task),
tasks: make(chan *task),
maxIdle: int32(maxIdle),
maxIdleTime: maxIdleTime,
}
@@ -63,69 +59,84 @@ func (p *Pool) Size() int32 {
return atomic.LoadInt32(&p.size)
}

// Go creates or reuses a worker to run the task.
func (p *Pool) Go(f func()) {
p.GoCtx(context.Background(), f)
}

// GoCtx creates or reuses a worker to run the task.
func (p *Pool) GoCtx(ctx context.Context, f func()) {
t := Task{ctx: ctx, f: f}
select {
case p.tasks <- t:
// reuse an existing worker
return
default:
}
func (p *Pool) createTicker() {
p.mu.Lock()
defer p.mu.Unlock()

// single shot if p.size > p.maxIdle
if atomic.AddInt32(&p.size, 1) > p.maxIdle {
go func(t Task) {
defer func() {
if r := recover(); r != nil {
klog.Errorf("panic in wpool: error=%v: stack=%s", r, debug.Stack())
}
atomic.AddInt32(&p.size, -1)
}()
if profiler.IsEnabled(t.ctx) {
profiler.Tag(t.ctx)
t.f()
profiler.Untag(t.ctx)
} else {
t.f()
}
}(t)
return
// make sure the previous ticker goroutine is closed before creating a new one
if p.ticker != nil {
close(p.ticker)
}
ch := make(chan struct{})
p.ticker = ch

// background goroutines for consuming tasks
go func(t Task) {
defer func() {
if r := recover(); r != nil {
klog.Errorf("panic in wpool: error=%v: stack=%s", r, debug.Stack())
}
atomic.AddInt32(&p.size, -1)
}()
// waiting for new task
idleTimer := time.NewTimer(p.maxIdleTime)
for {
if profiler.IsEnabled(t.ctx) {
profiler.Tag(t.ctx)
t.f()
profiler.Untag(t.ctx)
} else {
t.f()
}
idleTimer.Reset(p.maxIdleTime)
go func(done <-chan struct{}) {
// e.g. with maxIdleTime=60s and maxIdle=100, a noop task is sent every 60ms;
// d is clamped to at least 10ms, so idle goroutines may take a little
// longer to exit, which is acceptable.
d := p.maxIdleTime / time.Duration(p.maxIdle) / 10
if d < 10*time.Millisecond {
d = 10 * time.Millisecond
}
tk := time.NewTicker(d)
for p.Size() > 0 {
select {
case t = <-p.tasks:
case <-idleTimer.C:
// worker exits
case <-tk.C:
case <-done:
return
}
if !idleTimer.Stop() {
<-idleTimer.C
select {
case p.tasks <- nil: // noop task for checking idle time
case <-tk.C:
}
}
}(t)
}(ch)
}

func (p *Pool) createWorker(t *task) bool {
if n := atomic.AddInt32(&p.size, 1); n < p.maxIdle {
if n == 1 {
p.createTicker()
}
go func(t *task) {
defer atomic.AddInt32(&p.size, -1)

t.Run()

lastactive := time.Now()
for t := range p.tasks {
if t == nil { // from `createTicker` func
if time.Since(lastactive) > p.maxIdleTime {
break
}
continue
}
t.Run()
lastactive = time.Now()
}
}(t)
return true
} else {
atomic.AddInt32(&p.size, -1)
return false
}
}

// RunTask creates or reuses a worker to run the task.
func (p *Pool) RunTask(ctx context.Context, timeout time.Duration,
req, resp any, ep endpoint.Endpoint,
) (context.Context, error) {
t := newTask(ctx, timeout, req, resp, ep)
select {
case p.tasks <- t:
return t.Wait()
default:
}
if !p.createWorker(t) {
// if a worker was created, t.Run() is called inside the worker goroutine;
// if NOT, run the task in its own goroutine here.
go t.Run()
}
return t.Wait()
}
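
For orientation, a minimal sketch of how a caller wires this pool up, using only the signatures visible in this diff. The sizing values and the helper name are assumptions, wpool is an internal package (so this only compiles inside the kitex module), and client/rpctimeout.go above is the real call site:

package client

import (
    "context"
    "fmt"
    "time"

    "github.com/cloudwego/kitex/internal/wpool"
    "github.com/cloudwego/kitex/pkg/endpoint"
)

// assumed sizing: at most 128 idle workers, each exiting after ~10s idle
var workerPool = wpool.New(128, 10*time.Second)

func callWithTimeout(ctx context.Context, tm time.Duration, req, resp interface{}, ep endpoint.Endpoint) error {
    start := time.Now()
    ctx, err := workerPool.RunTask(ctx, tm, req, resp, ep)
    if ctx.Err() != nil {
        // the deadline passed or the caller canceled before ep returned;
        // the middleware maps this to a Kitex timeout error via makeTimeoutErr.
        return fmt.Errorf("rpc timeout after %s (elapsed %s): %w", tm, time.Since(start), ctx.Err())
    }
    return err
}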