Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce context aware helpers for context related testing #86

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package clockwork

import (
"context"
"fmt"
"sync"
"time"
)

// contextKey is private to this package so we can ensure uniqueness here. This
Expand All @@ -28,3 +31,98 @@ func FromContext(ctx context.Context) Clock {
}
return NewRealClock()
}

// ClockedContext is a interface that extends the context.Context interface with
// methods for creating new contexts with timeouts and deadlines with a controlled clock.
type ClockedContext interface {
context.Context
WithTimeout(parent context.Context, timeout time.Duration) (ClockedContext, context.CancelFunc)
WithDeadline(parent context.Context, deadline time.Time) (ClockedContext, context.CancelFunc)
}

// WrapContext creates a new context that uses the provided clock for timeouts and deadlines.
func WrapContext(parent context.Context, clock Clock) ClockedContext {
ctx := &timerCtx{
clock: clock,
parent: parent,
done: make(chan struct{}),
}
propagateCancel(parent, ctx)
return ctx
}

func (c *timerCtx) WithTimeout(parent context.Context, timeout time.Duration) (ClockedContext, context.CancelFunc) {
return c.WithDeadline(parent, c.clock.Now().Add(timeout))
}

func (c *timerCtx) WithDeadline(parent context.Context, deadline time.Time) (ClockedContext, context.CancelFunc) {
if cur, ok := parent.Deadline(); ok && cur.Before(deadline) {
// The current deadline is already sooner than the new one.
return c, func() { c.cancel(parent.Err()) }
}
dur := deadline.Sub(c.clock.Now())
if dur <= 0 {
c.cancel(context.DeadlineExceeded) // deadline has already passed
return c, func() {}
}
c.Lock()
defer c.Unlock()
if c.err == nil {
c.timer = c.clock.AfterFunc(dur, func() {
c.cancel(context.DeadlineExceeded)
})
}
return c, func() { c.cancel(context.Canceled) }
}

// propagateCancel arranges for child to be canceled when parent is.
func propagateCancel(parent context.Context, child *timerCtx) {
if parent.Done() == nil {
return // parent is never canceled
}
go func() {
select {
case <-parent.Done():
child.cancel(parent.Err())
case <-child.Done():
}
}()
}

type timerCtx struct {
sync.Mutex

clock Clock
parent context.Context
deadline time.Time
done chan struct{}

err error
timer Timer
}

func (c *timerCtx) cancel(err error) {
c.Lock()
defer c.Unlock()
if c.err != nil {
return // already canceled
}
c.err = err
close(c.done)
if c.timer != nil {
c.timer.Stop()
c.timer = nil
}
}

func (c *timerCtx) Deadline() (deadline time.Time, ok bool) { return c.deadline, !c.deadline.IsZero() }

func (c *timerCtx) Done() <-chan struct{} { return c.done }

func (c *timerCtx) Err() error { return c.err }

func (c *timerCtx) Value(key interface{}) interface{} { return c.parent.Value(key) }

func (c *timerCtx) String() string {
return fmt.Sprintf("clock.WithDeadline(%s [%s])", c.deadline, c.deadline.Sub(c.clock.Now()))
}
100 changes: 100 additions & 0 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package clockwork

import (
"context"
"errors"
"reflect"
"testing"
"time"
)

func TestContextOps(t *testing.T) {
Expand All @@ -24,3 +26,101 @@ func assertIsType(t *testing.T, expectedType, object interface{}) {
t.Fatalf("Object expected to be of type %v, but was %v", reflect.TypeOf(expectedType), reflect.TypeOf(object))
}
}

// Ensure that WithDeadline is cancelled when deadline exceeded.
func TestFakeClock_WithDeadline(t *testing.T) {
m := NewFakeClock()
now := m.Now()
wrappedContext := WrapContext(context.Background(), m)
ctx, _ := wrappedContext.WithDeadline(wrappedContext, now.Add(time.Second))
m.Advance(time.Second)
select {
case <-ctx.Done():
if !errors.Is(ctx.Err(), context.DeadlineExceeded) {
t.Error("invalid type of error returned when deadline exceeded")
}
default:
t.Error("context is not cancelled when deadline exceeded")
}
}

// Ensure that WithDeadline does nothing when the deadline is later than the current deadline.
func TestFakeClock_WithDeadlineLaterThanCurrent(t *testing.T) {
m := NewFakeClock()
wrappedContext := WrapContext(context.Background(), m)
ctx, _ := wrappedContext.WithDeadline(wrappedContext, m.Now().Add(time.Second))
ctx, _ = wrappedContext.WithDeadline(ctx, m.Now().Add(10*time.Second))
m.Advance(time.Second)
select {
case <-ctx.Done():
if !errors.Is(ctx.Err(), context.DeadlineExceeded) {
t.Error("invalid type of error returned when deadline exceeded")
}
default:
t.Error("context is not cancelled when deadline exceeded")
}
}

// Ensure that WithDeadline cancel closes Done channel with context.Canceled error.
func TestFakeClock_WithDeadlineCancel(t *testing.T) {
m := NewFakeClock()
wrappedContext := WrapContext(context.Background(), m)
ctx, cancel := wrappedContext.WithDeadline(context.Background(), m.Now().Add(time.Second))
cancel()
select {
case <-ctx.Done():
if !errors.Is(ctx.Err(), context.Canceled) {
t.Error("invalid type of error returned after cancellation")
}
case <-time.After(time.Second):
t.Error("context is not cancelled after cancel was called")
}
}

// Ensure that WithDeadline closes child contexts after it was closed.
func TestFakeClock_WithDeadlineCancelledWithParent(t *testing.T) {
m := NewFakeClock()
parent, cancel := context.WithCancel(context.Background())
wrappedContext := WrapContext(parent, m)
ctx, _ := wrappedContext.WithDeadline(parent, m.Now().Add(time.Second))
cancel()
select {
case <-ctx.Done():
if !errors.Is(ctx.Err(), context.Canceled) {
t.Error("invalid type of error returned after cancellation")
}
case <-time.After(time.Second):
t.Error("context is not cancelled when parent context is cancelled")
}
}

// Ensure that WithDeadline cancelled immediately when deadline has already passed.
func TestFakeClock_WithDeadlineImmediate(t *testing.T) {
m := NewFakeClock()
wrappedContext := WrapContext(context.Background(), m)
ctx, _ := wrappedContext.WithDeadline(context.Background(), m.Now().Add(-time.Second))
select {
case <-ctx.Done():
if !errors.Is(ctx.Err(), context.DeadlineExceeded) {
t.Error("invalid type of error returned when deadline has already passed")
}
default:
t.Error("context is not cancelled when deadline has already passed")
}
}

// Ensure that WithTimeout is cancelled when deadline exceeded.
func TestFakeClock_WithTimeout(t *testing.T) {
m := NewFakeClock()
wrappedContext := WrapContext(context.Background(), m)
ctx, _ := wrappedContext.WithTimeout(context.Background(), time.Second)
m.Advance(time.Second)
select {
case <-ctx.Done():
if !errors.Is(ctx.Err(), context.DeadlineExceeded) {
t.Error("invalid type of error returned when time is over")
}
default:
t.Error("context is not cancelled when time is over")
}
}
8 changes: 7 additions & 1 deletion timer.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package clockwork

import "time"
import (
"time"
)

// Timer provides an interface which can be used instead of directly using
// [time.Timer]. The real-time timer t provides events through t.C which becomes
Expand Down Expand Up @@ -40,6 +42,7 @@ func (f *fakeTimer) Stop() bool {

func (f *fakeTimer) expire(now time.Time) *time.Duration {
if f.afterFunc != nil {
defer gosched()
go f.afterFunc()
return nil
}
Expand All @@ -51,3 +54,6 @@ func (f *fakeTimer) expire(now time.Time) *time.Duration {
}
return nil
}

// Sleep momentarily so that other goroutines can process.
func gosched() { time.Sleep(time.Millisecond) }