diff --git a/context.go b/context.go index 40906e7..3915294 100644 --- a/context.go +++ b/context.go @@ -2,6 +2,9 @@ package clockwork import ( "context" + "fmt" + "sync" + "time" ) // contextKey is private to this package so we can ensure uniqueness here. This @@ -28,3 +31,98 @@ func FromContext(ctx context.Context) Clock { } return NewRealClock() } + +// ClockedContext is a interface that extends the context.Context interface with +// methods for creating new contexts with timeouts and deadlines with a controlled clock. +type ClockedContext interface { + context.Context + WithTimeout(parent context.Context, timeout time.Duration) (ClockedContext, context.CancelFunc) + WithDeadline(parent context.Context, deadline time.Time) (ClockedContext, context.CancelFunc) +} + +// WrapContext creates a new context that uses the provided clock for timeouts and deadlines. +func WrapContext(parent context.Context, clock Clock) ClockedContext { + ctx := &timerCtx{ + clock: clock, + parent: parent, + done: make(chan struct{}), + } + propagateCancel(parent, ctx) + return ctx +} + +func (c *timerCtx) WithTimeout(parent context.Context, timeout time.Duration) (ClockedContext, context.CancelFunc) { + return c.WithDeadline(parent, c.clock.Now().Add(timeout)) +} + +func (c *timerCtx) WithDeadline(parent context.Context, deadline time.Time) (ClockedContext, context.CancelFunc) { + if cur, ok := parent.Deadline(); ok && cur.Before(deadline) { + // The current deadline is already sooner than the new one. + return c, func() { c.cancel(parent.Err()) } + } + dur := deadline.Sub(c.clock.Now()) + if dur <= 0 { + c.cancel(context.DeadlineExceeded) // deadline has already passed + return c, func() {} + } + c.Lock() + defer c.Unlock() + if c.err == nil { + c.timer = c.clock.AfterFunc(dur, func() { + c.cancel(context.DeadlineExceeded) + }) + } + return c, func() { c.cancel(context.Canceled) } +} + +// propagateCancel arranges for child to be canceled when parent is. +func propagateCancel(parent context.Context, child *timerCtx) { + if parent.Done() == nil { + return // parent is never canceled + } + go func() { + select { + case <-parent.Done(): + child.cancel(parent.Err()) + case <-child.Done(): + } + }() +} + +type timerCtx struct { + sync.Mutex + + clock Clock + parent context.Context + deadline time.Time + done chan struct{} + + err error + timer Timer +} + +func (c *timerCtx) cancel(err error) { + c.Lock() + defer c.Unlock() + if c.err != nil { + return // already canceled + } + c.err = err + close(c.done) + if c.timer != nil { + c.timer.Stop() + c.timer = nil + } +} + +func (c *timerCtx) Deadline() (deadline time.Time, ok bool) { return c.deadline, !c.deadline.IsZero() } + +func (c *timerCtx) Done() <-chan struct{} { return c.done } + +func (c *timerCtx) Err() error { return c.err } + +func (c *timerCtx) Value(key interface{}) interface{} { return c.parent.Value(key) } + +func (c *timerCtx) String() string { + return fmt.Sprintf("clock.WithDeadline(%s [%s])", c.deadline, c.deadline.Sub(c.clock.Now())) +} diff --git a/context_test.go b/context_test.go index ee10d5b..b620d09 100644 --- a/context_test.go +++ b/context_test.go @@ -2,8 +2,10 @@ package clockwork import ( "context" + "errors" "reflect" "testing" + "time" ) func TestContextOps(t *testing.T) { @@ -24,3 +26,101 @@ func assertIsType(t *testing.T, expectedType, object interface{}) { t.Fatalf("Object expected to be of type %v, but was %v", reflect.TypeOf(expectedType), reflect.TypeOf(object)) } } + +// Ensure that WithDeadline is cancelled when deadline exceeded. +func TestFakeClock_WithDeadline(t *testing.T) { + m := NewFakeClock() + now := m.Now() + wrappedContext := WrapContext(context.Background(), m) + ctx, _ := wrappedContext.WithDeadline(wrappedContext, now.Add(time.Second)) + m.Advance(time.Second) + select { + case <-ctx.Done(): + if !errors.Is(ctx.Err(), context.DeadlineExceeded) { + t.Error("invalid type of error returned when deadline exceeded") + } + default: + t.Error("context is not cancelled when deadline exceeded") + } +} + +// Ensure that WithDeadline does nothing when the deadline is later than the current deadline. +func TestFakeClock_WithDeadlineLaterThanCurrent(t *testing.T) { + m := NewFakeClock() + wrappedContext := WrapContext(context.Background(), m) + ctx, _ := wrappedContext.WithDeadline(wrappedContext, m.Now().Add(time.Second)) + ctx, _ = wrappedContext.WithDeadline(ctx, m.Now().Add(10*time.Second)) + m.Advance(time.Second) + select { + case <-ctx.Done(): + if !errors.Is(ctx.Err(), context.DeadlineExceeded) { + t.Error("invalid type of error returned when deadline exceeded") + } + default: + t.Error("context is not cancelled when deadline exceeded") + } +} + +// Ensure that WithDeadline cancel closes Done channel with context.Canceled error. +func TestFakeClock_WithDeadlineCancel(t *testing.T) { + m := NewFakeClock() + wrappedContext := WrapContext(context.Background(), m) + ctx, cancel := wrappedContext.WithDeadline(context.Background(), m.Now().Add(time.Second)) + cancel() + select { + case <-ctx.Done(): + if !errors.Is(ctx.Err(), context.Canceled) { + t.Error("invalid type of error returned after cancellation") + } + case <-time.After(time.Second): + t.Error("context is not cancelled after cancel was called") + } +} + +// Ensure that WithDeadline closes child contexts after it was closed. +func TestFakeClock_WithDeadlineCancelledWithParent(t *testing.T) { + m := NewFakeClock() + parent, cancel := context.WithCancel(context.Background()) + wrappedContext := WrapContext(parent, m) + ctx, _ := wrappedContext.WithDeadline(parent, m.Now().Add(time.Second)) + cancel() + select { + case <-ctx.Done(): + if !errors.Is(ctx.Err(), context.Canceled) { + t.Error("invalid type of error returned after cancellation") + } + case <-time.After(time.Second): + t.Error("context is not cancelled when parent context is cancelled") + } +} + +// Ensure that WithDeadline cancelled immediately when deadline has already passed. +func TestFakeClock_WithDeadlineImmediate(t *testing.T) { + m := NewFakeClock() + wrappedContext := WrapContext(context.Background(), m) + ctx, _ := wrappedContext.WithDeadline(context.Background(), m.Now().Add(-time.Second)) + select { + case <-ctx.Done(): + if !errors.Is(ctx.Err(), context.DeadlineExceeded) { + t.Error("invalid type of error returned when deadline has already passed") + } + default: + t.Error("context is not cancelled when deadline has already passed") + } +} + +// Ensure that WithTimeout is cancelled when deadline exceeded. +func TestFakeClock_WithTimeout(t *testing.T) { + m := NewFakeClock() + wrappedContext := WrapContext(context.Background(), m) + ctx, _ := wrappedContext.WithTimeout(context.Background(), time.Second) + m.Advance(time.Second) + select { + case <-ctx.Done(): + if !errors.Is(ctx.Err(), context.DeadlineExceeded) { + t.Error("invalid type of error returned when time is over") + } + default: + t.Error("context is not cancelled when time is over") + } +} diff --git a/timer.go b/timer.go index 4e28e63..322ea4c 100644 --- a/timer.go +++ b/timer.go @@ -1,6 +1,8 @@ package clockwork -import "time" +import ( + "time" +) // Timer provides an interface which can be used instead of directly using // [time.Timer]. The real-time timer t provides events through t.C which becomes @@ -40,6 +42,7 @@ func (f *fakeTimer) Stop() bool { func (f *fakeTimer) expire(now time.Time) *time.Duration { if f.afterFunc != nil { + defer gosched() go f.afterFunc() return nil } @@ -51,3 +54,6 @@ func (f *fakeTimer) expire(now time.Time) *time.Duration { } return nil } + +// Sleep momentarily so that other goroutines can process. +func gosched() { time.Sleep(time.Millisecond) }