group_test.go
package parallel

import (
"context"
"runtime"
"sync"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
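
// contextLeak collects the contexts handed to worker functions so the test
// can later assert that every one of them was canceled (and, optionally, why).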
type contextLeak struct {
lock sync.Mutex
ctxs []context.Context
}

func (c *contextLeak) leak(ctx context.Context) {
c.lock.Lock()
defer c.lock.Unlock()
c.ctxs = append(c.ctxs, ctx)
}
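
// assertAllCanceled fails the test unless every collected context has been
// canceled; if an expected error is supplied, each cancellation cause must
// match it.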
func (c *contextLeak) assertAllCanceled(t *testing.T, expected ...error) {
t.Helper()
if len(expected) > 1 {
panic("please provide at most one expected error for all the contexts")
}
c.lock.Lock()
defer c.lock.Unlock()
for _, ctx := range c.ctxs {
cause := context.Cause(ctx)
if cause == nil {
t.Fatal("context was not canceled")
}
if len(expected) == 1 {
require.ErrorIs(t, cause, expected[0])
}
}
}
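
// assertPanicsWithValue runs f and fails the test unless f panics with a
// WorkerPanic whose Panic field equals expectedValue.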
func assertPanicsWithValue(t *testing.T, expectedValue any, f func()) {
t.Helper()
defer func() {
p := recover()
if p == nil {
t.Fatal("didn't panic but should have")
}
assert.Equal(t, expectedValue, p.(WorkerPanic).Panic)
}()
f()
}
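
// TestGroup runs the shared executor test suite against the unlimited,
// limited, and serial (limit 0) flavors.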
func TestGroup(t *testing.T) {
for _, test := range []struct {
name string
makeExec func(context.Context) Executor
}{
{"Unlimited", Unlimited},
{"Limited", func(ctx context.Context) Executor { return Limited(ctx, 10) }},
{"serial", func(ctx context.Context) Executor { return Limited(ctx, 0) }},
} {
t.Run(test.name, func(t *testing.T) {
testGroup(t, test.makeExec)
})
}
}
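
// testGroup exercises whichever Executor flavor makeExec produces through a
// set of common behavioral subtests.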
func testGroup(t *testing.T, makeExec func(context.Context) Executor) {
t.Parallel()
t.Run("do nothing", func(t *testing.T) {
t.Parallel()
g := makeExec(context.Background())
g.Wait()
})
t.Run("do nothing canceled", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
g := makeExec(ctx)
cancel()
g.Wait()
})
t.Run("sum 100", func(t *testing.T) {
t.Parallel()
var counter int64
var leak contextLeak
g := makeExec(context.Background())
for i := 0; i < 100; i++ {
g.Go(func(ctx context.Context) {
leak.leak(ctx)
atomic.AddInt64(&counter, 1)
})
}
g.Wait()
assert.Equal(t, int64(100), counter)
leak.assertAllCanceled(t, errGroupDone)
})
t.Run("sum canceled", func(t *testing.T) {
t.Parallel()
var counter int64
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
g := makeExec(ctx)
for i := 0; i < 100; i++ {
if i == 50 {
cancel()
}
g.Go(func(context.Context) {
atomic.AddInt64(&counter, 1)
})
}
g.Wait()
// Work submitted after the context has been canceled does not happen.
// We cannot guarantee that the counter isn't less than 50, because some
// of the original 50 work units might not have started yet. We also
// cannot guarantee that the counter isn't *more* than 50 because in the
// limited executor, some of the worker functions may select a work item
// instead of seeing the done signal on their final loop.
var maxSum int64 = 50
if lg, ok := g.(*limitedGroup); ok {
maxSum += int64(lg.max) // limitedGroup may run up to 1 more per worker
}
assert.LessOrEqual(t, counter, maxSum)
})
t.Run("wait multiple times", func(t *testing.T) {
t.Parallel()
g := makeExec(context.Background())
assert.NotPanics(t, g.Wait)
assert.NotPanics(t, g.Wait)
})
}
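
// testLimitedGroupMaxConcurrency asserts that g never runs more than limit
// jobs at once. With shouldSucceed set to false it instead demonstrates that
// the test itself can catch an executor running more than limit workers.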
func testLimitedGroupMaxConcurrency(t *testing.T, name string, g Executor, limit int, shouldSucceed bool) {
// Testing that some process can work with *at least* N parallelism is easy:
// we run N jobs that cannot make progress, and unblock them when they have
// all arrived at that blocker.
//
// Coming up with a way to validate that something runs with *NO MORE THAN*
// N parallelism is HARD.
//
// We can't just time.Sleep and wait for everything to catch up, because
// that simply isn't how concurrency works, especially in test environments:
// there's no amount of time we can choose that will actually guarantee
// another thread has caught up. So instead, we first assert that exactly N
// jobs are running in the executor in parallel, and then we insert lots and
// lots of poison pills into the work queue and *footrace* with any other
// worker threads that might have started and could be trying to run jobs,
// while also reaching under the hood and discarding those work units
// ourselves. Go channels are sufficiently fair that, if there are multiple
// waiters, all of them will eventually receive at least *some* of the items
// in the channel, which gives us a very high probability that any such
// worker will choke on a poison pill if one exists.
t.Run(name, func(t *testing.T) {
t.Parallel()
var blocker, barrier sync.WaitGroup
// Blocker stops the workers from progressing.
blocker.Add(1)
// Barrier lets us know when all the workers have arrived. If this
// test hangs, it's probably because not enough workers started.
barrier.Add(limit)
jobInserter := Unlimited(context.Background())
jobInserter.Go(func(context.Context) {
// We fully loop over the ops channel in the test to empty it. The
// channel is only closed when the group is awaited or forgotten but
// not when it panics, and simply guaranteeing that we await it takes
// the least code, so that's what we do.
defer g.Wait()
for i := 0; i < limit; i++ {
g.Go(func(context.Context) {
barrier.Done()
blocker.Wait()
})
}
// Now we insert a whole buttload of jobs that should never be picked
// up and run by the executor. We will go through and consume these
// from the channel ourselves in the main thread, but if there were
// any extra workers taking from that channel, chances are they would
// pick up and run at least one of these jobs, failing the test.
for i := 0; i < 10000; i++ {
g.Go(func(context.Context) {
panic("poison pill")
})
}
g.Wait()
})
barrier.Wait()
// All the workers we *expect* to see have shown up now. Throw away all
// the poison pills in the ops queue.
for poisonPill := range g.(*limitedGroup).ops {
runtime.Gosched() // Trigger preemption as much as we can
assert.NotNil(t, poisonPill)
runtime.Gosched() // Trigger preemption as much as we can
}
blocker.Done() // unblock the workers
if shouldSucceed {
assert.NotPanics(t, jobInserter.Wait)
} else {
assertPanicsWithValue(t, "poison pill", jobInserter.Wait)
}
})
}

func TestLimitedGroupMaxConcurrency(t *testing.T) {
t.Parallel()
testLimitedGroupMaxConcurrency(t, "100", Limited(context.Background(), 100), 100, true)
testLimitedGroupMaxConcurrency(t, "50", Limited(context.Background(), 50), 50, true)
testLimitedGroupMaxConcurrency(t, "5", Limited(context.Background(), 5), 5, true)
testLimitedGroupMaxConcurrency(t, "1", Limited(context.Background(), 1), 1, true)
// Validate the test itself: an executor with 6 workers should trip the 5-worker check.
testLimitedGroupMaxConcurrency(t, "fail", Limited(context.Background(), 6), 5, false)
}

func TestConcurrentGroupWaitReallyWaits(t *testing.T) {
testConcurrentGroupWaitReallyWaits(t, "Unlimited", Unlimited(context.Background()))
testConcurrentGroupWaitReallyWaits(t, "Limited", Limited(context.Background(), 2))
}
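
// testConcurrentGroupWaitReallyWaits checks that concurrent Wait() calls on g
// block until the group's work is done: every waiter pokes a canary channel
// once its Wait() returns, and the main goroutine repeatedly yields while
// asserting that the canary stays empty for as long as a worker is still
// blocked.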
func testConcurrentGroupWaitReallyWaits(t *testing.T, name string, g Executor) {
const parallelWaiters = 100
t.Run(name, func(t *testing.T) {
var blocker sync.WaitGroup
blocker.Add(1)
g.Go(func(context.Context) {
blocker.Wait()
})
failureCanary := make(chan struct{}, parallelWaiters)
// Wait for the group many times concurrently
testingGroup := Unlimited(context.Background())
for i := 0; i < parallelWaiters; i++ {
testingGroup.Go(func(context.Context) {
g.Wait()
failureCanary <- struct{}{}
})
}
// Give the testing group lots and lots of chances to make progress
for i := 0; i < 100000; i++ {
select {
case <-failureCanary:
t.Fatal("a Wait() call made progress when it shouldn't have!")
default:
}
runtime.Gosched()
}
// Clean up
blocker.Done()
for i := 0; i < parallelWaiters; i++ {
<-failureCanary
}
testingGroup.Wait()
})
}
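
// TestCanGoexit verifies that a worker may call runtime.Goexit() (as t.Fatal
// does under the hood) without the group treating it as a panic.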
func TestCanGoexit(t *testing.T) {
g := Unlimited(context.Background())
g.Go(func(context.Context) {
// Ideally we would test t.Fatal() here to show that parallel now plays
// nicely with the testing lib, but there doesn't seem to be any good
// way to xfail a Go test. As it happens, t.Fatal() just sets a fail
// flag and then calls Goexit() anyway; if we treat a nil recover() value
// as Goexit() (a safe assumption since Go 1.21, when panic(nil) started
// being delivered as a PanicNilError) we can handle this very simply,
// without needing a "double defer sandwich".
//
// Either way, we expect Goexit() to work normally in tests now and not
// fail or re-panic.
runtime.Goexit()
})
g.Wait()
}
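
// The sketch below is illustrative only and is not used by the library or the
// tests above: it spells out, under the assumption of a worker wrapper shaped
// roughly like the one this package uses, how the "treat nil recover() as
// Goexit()" idea from TestCanGoexit can be applied. The names sketchRunWorker
// and onPanic are hypothetical. Since Go 1.21, panic(nil) reaches recover()
// as a *runtime.PanicNilError, so a nil value from recover() can only mean a
// normal return or runtime.Goexit(), and neither needs special handling.
func sketchRunWorker(f func(), onPanic func(any)) {
	defer func() {
		// This deferred function runs after a normal return, after
		// runtime.Goexit(), and after a panic.
		if p := recover(); p != nil {
			// A genuine panic (including panic(nil), which arrives as a
			// *runtime.PanicNilError): hand it off, e.g. for re-raising
			// from Wait().
			onPanic(p)
		}
		// recover() returned nil: f either returned normally or called
		// Goexit(); in both cases the goroutine is simply allowed to
		// finish unwinding.
	}()
	f()
}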