From 41140a2c03f62d4e212777d72728a6f799c8dd2c Mon Sep 17 00:00:00 2001 From: Ilya Ozherelyev Date: Fri, 23 Aug 2024 12:01:13 +0200 Subject: [PATCH 1/2] Introduce context aware helpers for context related testing --- clockwork.go | 2 ++ context.go | 89 ++++++++++++++++++++++++++++++++++++++++++++++ context_test.go | 94 +++++++++++++++++++++++++++++++++++++++++++++++++ timer.go | 8 ++++- 4 files changed, 192 insertions(+), 1 deletion(-) diff --git a/clockwork.go b/clockwork.go index df02ba9..f8a8ada 100644 --- a/clockwork.go +++ b/clockwork.go @@ -19,6 +19,8 @@ type Clock interface { NewTicker(d time.Duration) Ticker NewTimer(d time.Duration) Timer AfterFunc(d time.Duration, f func()) Timer + WithTimeout(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) + WithDeadline(parent context.Context, deadline time.Time) (context.Context, context.CancelFunc) } // NewRealClock returns a Clock which simply delegates calls to the actual time diff --git a/context.go b/context.go index 40906e7..c3bc9bf 100644 --- a/context.go +++ b/context.go @@ -2,6 +2,9 @@ package clockwork import ( "context" + "fmt" + "sync" + "time" ) // contextKey is private to this package so we can ensure uniqueness here. This @@ -28,3 +31,89 @@ func FromContext(ctx context.Context) Clock { } return NewRealClock() } + +func (rc *realClock) WithTimeout(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + return context.WithTimeout(parent, timeout) +} + +func (rc *realClock) WithDeadline(parent context.Context, deadline time.Time) (context.Context, context.CancelFunc) { + return context.WithDeadline(parent, deadline) +} + +func (fc *FakeClock) WithTimeout(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + return fc.WithDeadline(parent, fc.Now().Add(timeout)) +} + +func (fc *FakeClock) WithDeadline(parent context.Context, deadline time.Time) (context.Context, context.CancelFunc) { + if cur, ok := parent.Deadline(); ok && cur.Before(deadline) { + // The current deadline is already sooner than the new one. + return context.WithCancel(parent) + } + ctx := &timerCtx{clock: fc, parent: parent, deadline: deadline, done: make(chan struct{})} + propagateCancel(parent, ctx) + dur := deadline.Sub(fc.Now()) + if dur <= 0 { + ctx.cancel(context.DeadlineExceeded) // deadline has already passed + return ctx, func() {} + } + ctx.Lock() + defer ctx.Unlock() + if ctx.err == nil { + ctx.timer = fc.AfterFunc(dur, func() { + ctx.cancel(context.DeadlineExceeded) + }) + } + return ctx, func() { ctx.cancel(context.Canceled) } +} + +// propagateCancel arranges for child to be canceled when parent is. +func propagateCancel(parent context.Context, child *timerCtx) { + if parent.Done() == nil { + return // parent is never canceled + } + go func() { + select { + case <-parent.Done(): + child.cancel(parent.Err()) + case <-child.Done(): + } + }() +} + +type timerCtx struct { + sync.Mutex + + clock Clock + parent context.Context + deadline time.Time + done chan struct{} + + err error + timer Timer +} + +func (c *timerCtx) cancel(err error) { + c.Lock() + defer c.Unlock() + if c.err != nil { + return // already canceled + } + c.err = err + close(c.done) + if c.timer != nil { + c.timer.Stop() + c.timer = nil + } +} + +func (c *timerCtx) Deadline() (deadline time.Time, ok bool) { return c.deadline, true } + +func (c *timerCtx) Done() <-chan struct{} { return c.done } + +func (c *timerCtx) Err() error { return c.err } + +func (c *timerCtx) Value(key interface{}) interface{} { return c.parent.Value(key) } + +func (c *timerCtx) String() string { + return fmt.Sprintf("clock.WithDeadline(%s [%s])", c.deadline, c.deadline.Sub(c.clock.Now())) +} diff --git a/context_test.go b/context_test.go index ee10d5b..6e3d708 100644 --- a/context_test.go +++ b/context_test.go @@ -2,8 +2,10 @@ package clockwork import ( "context" + "errors" "reflect" "testing" + "time" ) func TestContextOps(t *testing.T) { @@ -24,3 +26,95 @@ func assertIsType(t *testing.T, expectedType, object interface{}) { t.Fatalf("Object expected to be of type %v, but was %v", reflect.TypeOf(expectedType), reflect.TypeOf(object)) } } + +// Ensure that WithDeadline is cancelled when deadline exceeded. +func TestFakeClock_WithDeadline(t *testing.T) { + m := NewFakeClock() + now := m.Now() + ctx, _ := m.WithDeadline(context.Background(), now.Add(time.Second)) + m.Advance(time.Second) + select { + case <-ctx.Done(): + if !errors.Is(ctx.Err(), context.DeadlineExceeded) { + t.Error("invalid type of error returned when deadline exceeded") + } + default: + t.Error("context is not cancelled when deadline exceeded") + } +} + +// Ensure that WithDeadline does nothing when the deadline is later than the current deadline. +func TestFakeClock_WithDeadlineLaterThanCurrent(t *testing.T) { + m := NewFakeClock() + ctx, _ := m.WithDeadline(context.Background(), m.Now().Add(time.Second)) + ctx, _ = m.WithDeadline(ctx, m.Now().Add(10*time.Second)) + m.Advance(time.Second) + select { + case <-ctx.Done(): + if !errors.Is(ctx.Err(), context.DeadlineExceeded) { + t.Error("invalid type of error returned when deadline exceeded") + } + default: + t.Error("context is not cancelled when deadline exceeded") + } +} + +// Ensure that WithDeadline cancel closes Done channel with context.Canceled error. +func TestFakeClock_WithDeadlineCancel(t *testing.T) { + m := NewFakeClock() + ctx, cancel := m.WithDeadline(context.Background(), m.Now().Add(time.Second)) + cancel() + select { + case <-ctx.Done(): + if !errors.Is(ctx.Err(), context.Canceled) { + t.Error("invalid type of error returned after cancellation") + } + case <-time.After(time.Second): + t.Error("context is not cancelled after cancel was called") + } +} + +// Ensure that WithDeadline closes child contexts after it was closed. +func TestFakeClock_WithDeadlineCancelledWithParent(t *testing.T) { + m := NewFakeClock() + parent, cancel := context.WithCancel(context.Background()) + ctx, _ := m.WithDeadline(parent, m.Now().Add(time.Second)) + cancel() + select { + case <-ctx.Done(): + if !errors.Is(ctx.Err(), context.Canceled) { + t.Error("invalid type of error returned after cancellation") + } + case <-time.After(time.Second): + t.Error("context is not cancelled when parent context is cancelled") + } +} + +// Ensure that WithDeadline cancelled immediately when deadline has already passed. +func TestFakeClock_WithDeadlineImmediate(t *testing.T) { + m := NewFakeClock() + ctx, _ := m.WithDeadline(context.Background(), m.Now().Add(-time.Second)) + select { + case <-ctx.Done(): + if !errors.Is(ctx.Err(), context.DeadlineExceeded) { + t.Error("invalid type of error returned when deadline has already passed") + } + default: + t.Error("context is not cancelled when deadline has already passed") + } +} + +// Ensure that WithTimeout is cancelled when deadline exceeded. +func TestFakeClock_WithTimeout(t *testing.T) { + m := NewFakeClock() + ctx, _ := m.WithTimeout(context.Background(), time.Second) + m.Advance(time.Second) + select { + case <-ctx.Done(): + if !errors.Is(ctx.Err(), context.DeadlineExceeded) { + t.Error("invalid type of error returned when time is over") + } + default: + t.Error("context is not cancelled when time is over") + } +} diff --git a/timer.go b/timer.go index 4e28e63..322ea4c 100644 --- a/timer.go +++ b/timer.go @@ -1,6 +1,8 @@ package clockwork -import "time" +import ( + "time" +) // Timer provides an interface which can be used instead of directly using // [time.Timer]. The real-time timer t provides events through t.C which becomes @@ -40,6 +42,7 @@ func (f *fakeTimer) Stop() bool { func (f *fakeTimer) expire(now time.Time) *time.Duration { if f.afterFunc != nil { + defer gosched() go f.afterFunc() return nil } @@ -51,3 +54,6 @@ func (f *fakeTimer) expire(now time.Time) *time.Duration { } return nil } + +// Sleep momentarily so that other goroutines can process. +func gosched() { time.Sleep(time.Millisecond) } From 69bdf3823b3b839af254985de4d9f3666d63d930 Mon Sep 17 00:00:00 2001 From: Ilya Ozherelyev Date: Fri, 13 Sep 2024 16:50:00 +0200 Subject: [PATCH 2/2] alternative approach --- clockwork.go | 2 -- context.go | 49 +++++++++++++++++++++++++++++-------------------- context_test.go | 20 +++++++++++++------- 3 files changed, 42 insertions(+), 29 deletions(-) diff --git a/clockwork.go b/clockwork.go index f8a8ada..df02ba9 100644 --- a/clockwork.go +++ b/clockwork.go @@ -19,8 +19,6 @@ type Clock interface { NewTicker(d time.Duration) Ticker NewTimer(d time.Duration) Timer AfterFunc(d time.Duration, f func()) Timer - WithTimeout(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) - WithDeadline(parent context.Context, deadline time.Time) (context.Context, context.CancelFunc) } // NewRealClock returns a Clock which simply delegates calls to the actual time diff --git a/context.go b/context.go index c3bc9bf..3915294 100644 --- a/context.go +++ b/context.go @@ -32,38 +32,47 @@ func FromContext(ctx context.Context) Clock { return NewRealClock() } -func (rc *realClock) WithTimeout(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { - return context.WithTimeout(parent, timeout) +// ClockedContext is a interface that extends the context.Context interface with +// methods for creating new contexts with timeouts and deadlines with a controlled clock. +type ClockedContext interface { + context.Context + WithTimeout(parent context.Context, timeout time.Duration) (ClockedContext, context.CancelFunc) + WithDeadline(parent context.Context, deadline time.Time) (ClockedContext, context.CancelFunc) } -func (rc *realClock) WithDeadline(parent context.Context, deadline time.Time) (context.Context, context.CancelFunc) { - return context.WithDeadline(parent, deadline) +// WrapContext creates a new context that uses the provided clock for timeouts and deadlines. +func WrapContext(parent context.Context, clock Clock) ClockedContext { + ctx := &timerCtx{ + clock: clock, + parent: parent, + done: make(chan struct{}), + } + propagateCancel(parent, ctx) + return ctx } -func (fc *FakeClock) WithTimeout(parent context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { - return fc.WithDeadline(parent, fc.Now().Add(timeout)) +func (c *timerCtx) WithTimeout(parent context.Context, timeout time.Duration) (ClockedContext, context.CancelFunc) { + return c.WithDeadline(parent, c.clock.Now().Add(timeout)) } -func (fc *FakeClock) WithDeadline(parent context.Context, deadline time.Time) (context.Context, context.CancelFunc) { +func (c *timerCtx) WithDeadline(parent context.Context, deadline time.Time) (ClockedContext, context.CancelFunc) { if cur, ok := parent.Deadline(); ok && cur.Before(deadline) { // The current deadline is already sooner than the new one. - return context.WithCancel(parent) + return c, func() { c.cancel(parent.Err()) } } - ctx := &timerCtx{clock: fc, parent: parent, deadline: deadline, done: make(chan struct{})} - propagateCancel(parent, ctx) - dur := deadline.Sub(fc.Now()) + dur := deadline.Sub(c.clock.Now()) if dur <= 0 { - ctx.cancel(context.DeadlineExceeded) // deadline has already passed - return ctx, func() {} + c.cancel(context.DeadlineExceeded) // deadline has already passed + return c, func() {} } - ctx.Lock() - defer ctx.Unlock() - if ctx.err == nil { - ctx.timer = fc.AfterFunc(dur, func() { - ctx.cancel(context.DeadlineExceeded) + c.Lock() + defer c.Unlock() + if c.err == nil { + c.timer = c.clock.AfterFunc(dur, func() { + c.cancel(context.DeadlineExceeded) }) } - return ctx, func() { ctx.cancel(context.Canceled) } + return c, func() { c.cancel(context.Canceled) } } // propagateCancel arranges for child to be canceled when parent is. @@ -106,7 +115,7 @@ func (c *timerCtx) cancel(err error) { } } -func (c *timerCtx) Deadline() (deadline time.Time, ok bool) { return c.deadline, true } +func (c *timerCtx) Deadline() (deadline time.Time, ok bool) { return c.deadline, !c.deadline.IsZero() } func (c *timerCtx) Done() <-chan struct{} { return c.done } diff --git a/context_test.go b/context_test.go index 6e3d708..b620d09 100644 --- a/context_test.go +++ b/context_test.go @@ -31,7 +31,8 @@ func assertIsType(t *testing.T, expectedType, object interface{}) { func TestFakeClock_WithDeadline(t *testing.T) { m := NewFakeClock() now := m.Now() - ctx, _ := m.WithDeadline(context.Background(), now.Add(time.Second)) + wrappedContext := WrapContext(context.Background(), m) + ctx, _ := wrappedContext.WithDeadline(wrappedContext, now.Add(time.Second)) m.Advance(time.Second) select { case <-ctx.Done(): @@ -46,8 +47,9 @@ func TestFakeClock_WithDeadline(t *testing.T) { // Ensure that WithDeadline does nothing when the deadline is later than the current deadline. func TestFakeClock_WithDeadlineLaterThanCurrent(t *testing.T) { m := NewFakeClock() - ctx, _ := m.WithDeadline(context.Background(), m.Now().Add(time.Second)) - ctx, _ = m.WithDeadline(ctx, m.Now().Add(10*time.Second)) + wrappedContext := WrapContext(context.Background(), m) + ctx, _ := wrappedContext.WithDeadline(wrappedContext, m.Now().Add(time.Second)) + ctx, _ = wrappedContext.WithDeadline(ctx, m.Now().Add(10*time.Second)) m.Advance(time.Second) select { case <-ctx.Done(): @@ -62,7 +64,8 @@ func TestFakeClock_WithDeadlineLaterThanCurrent(t *testing.T) { // Ensure that WithDeadline cancel closes Done channel with context.Canceled error. func TestFakeClock_WithDeadlineCancel(t *testing.T) { m := NewFakeClock() - ctx, cancel := m.WithDeadline(context.Background(), m.Now().Add(time.Second)) + wrappedContext := WrapContext(context.Background(), m) + ctx, cancel := wrappedContext.WithDeadline(context.Background(), m.Now().Add(time.Second)) cancel() select { case <-ctx.Done(): @@ -78,7 +81,8 @@ func TestFakeClock_WithDeadlineCancel(t *testing.T) { func TestFakeClock_WithDeadlineCancelledWithParent(t *testing.T) { m := NewFakeClock() parent, cancel := context.WithCancel(context.Background()) - ctx, _ := m.WithDeadline(parent, m.Now().Add(time.Second)) + wrappedContext := WrapContext(parent, m) + ctx, _ := wrappedContext.WithDeadline(parent, m.Now().Add(time.Second)) cancel() select { case <-ctx.Done(): @@ -93,7 +97,8 @@ func TestFakeClock_WithDeadlineCancelledWithParent(t *testing.T) { // Ensure that WithDeadline cancelled immediately when deadline has already passed. func TestFakeClock_WithDeadlineImmediate(t *testing.T) { m := NewFakeClock() - ctx, _ := m.WithDeadline(context.Background(), m.Now().Add(-time.Second)) + wrappedContext := WrapContext(context.Background(), m) + ctx, _ := wrappedContext.WithDeadline(context.Background(), m.Now().Add(-time.Second)) select { case <-ctx.Done(): if !errors.Is(ctx.Err(), context.DeadlineExceeded) { @@ -107,7 +112,8 @@ func TestFakeClock_WithDeadlineImmediate(t *testing.T) { // Ensure that WithTimeout is cancelled when deadline exceeded. func TestFakeClock_WithTimeout(t *testing.T) { m := NewFakeClock() - ctx, _ := m.WithTimeout(context.Background(), time.Second) + wrappedContext := WrapContext(context.Background(), m) + ctx, _ := wrappedContext.WithTimeout(context.Background(), time.Second) m.Advance(time.Second) select { case <-ctx.Done():