diff --git a/go.mod b/go.mod index aba2d27..8215e37 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,8 @@ -module github.com/anantadwi13/worker +module github.com/anantadwi13/workers go 1.14 -require github.com/google/uuid v1.3.0 +require ( + github.com/google/uuid v1.3.0 + github.com/stretchr/testify v1.7.1 +) diff --git a/go.sum b/go.sum index 3dfe1c9..8378110 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,13 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/job.go b/job.go index 6347662..bad5df5 100644 --- a/job.go +++ b/job.go @@ -1,4 +1,4 @@ -package worker +package workers import ( "context" @@ -8,11 +8,15 @@ import ( ) type simpleJob struct { - id string - jobFunc JobFunc - resultChan chan *Result + id string + jobFunc JobFunc + requestData interface{} + resultChan chan *Result - cancelChan chan IsCanceled + status Status + statusLock sync.RWMutex + doneChan chan ChanSignal + cancelChan chan ChanSignal wg sync.WaitGroup } @@ -21,16 +25,19 @@ type Result struct { Error error } -type JobFunc func(isCanceled chan IsCanceled, result chan *Result) +type JobFunc func(isCanceled chan ChanSignal, requestData interface{}, result chan *Result) func NewJobSimple( - jobFunc JobFunc, result chan *Result, + jobFunc JobFunc, requestData interface{}, result chan *Result, ) Job { return &simpleJob{ - id: uuid.NewString(), - jobFunc: jobFunc, - resultChan: result, - cancelChan: make(chan IsCanceled, 2), + id: uuid.NewString(), + jobFunc: jobFunc, + requestData: requestData, + resultChan: result, + status: StatusCreated, + cancelChan: make(chan ChanSignal, 2), + doneChan: make(chan ChanSignal), } } @@ -39,35 +46,61 @@ func (s *simpleJob) Id() string { } func (s *simpleJob) Do(ctx context.Context) { - done := make(chan struct{}) + if s.Status() != StatusCreated { + return + } + + s.setStatus(StatusRunning) + defer s.setStatus(StatusStopped) + + done := make(chan ChanSignal) s.wg.Add(1) go func() { defer s.wg.Done() - s.jobFunc(s.cancelChan, s.resultChan) + s.jobFunc(s.cancelChan, s.requestData, s.resultChan) close(done) }() select { case <-done: - s.wg.Wait() case <-ctx.Done(): - s.cancelChan <- struct{}{} + s.cancelChan <- ChanSignal{} } + + s.wg.Wait() } func (s *simpleJob) Cancel(ctx context.Context) { - done := make(chan struct{}) - - go func() { - s.cancelChan <- struct{}{} - s.wg.Wait() - close(done) - }() + if s.Status() != StatusRunning { + return + } select { - case <-done: + case <-s.doneChan: case <-ctx.Done(): + s.cancelChan <- ChanSignal{} + } + + s.wg.Wait() +} + +func (s *simpleJob) Status() Status { + s.statusLock.RLock() + defer s.statusLock.RUnlock() + return s.status +} + +func (s *simpleJob) Done() chan ChanSignal { + return s.doneChan +} + +func (s *simpleJob) setStatus(status Status) { + s.statusLock.Lock() + defer s.statusLock.Unlock() + s.status = status + if s.status == StatusStopped { + close(s.doneChan) } } diff --git a/type.go b/type.go index 03f686e..4891a72 100644 --- a/type.go +++ b/type.go @@ -1,26 +1,40 @@ -package worker +package workers -import "context" +import ( + "context" + "time" +) + +type ChanSignal struct{} +type Status int + +const ( + StatusCreated Status = iota + StatusRunning + StatusStopped +) type Job interface { Id() string - // ctx contains a job timeout + Status() Status + Done() chan ChanSignal + + // Func below should be called by worker + + // Do should be blocking the process until the job is finished or canceled. ctx contains a job timeout Do(ctx context.Context) - // Cancel will block the process until the job is gracefully canceled. ctx contains a cancellation deadline + // Cancel should be blocking the process until the job is gracefully canceled. ctx contains a cancellation deadline Cancel(ctx context.Context) } type Worker interface { Start() error Shutdown() error + Status() Status - // GetJobTimeout returns a timeout in seconds - GetJobTimeout() int - // GetShutdownTimeout returns a timeout in seconds - GetShutdownTimeout() int + GetJobTimeout() time.Duration + GetShutdownTimeout() time.Duration Push(job Job) error PushAndWait(job Job) error } - -type IsCanceled struct{} diff --git a/worker_pool.go b/worker_pool.go index 8694055..7b26166 100644 --- a/worker_pool.go +++ b/worker_pool.go @@ -1,7 +1,8 @@ -package worker +package workers import ( "context" + "errors" "runtime" "sync" "time" @@ -11,16 +12,21 @@ type WorkerPoolConfig struct { QueueSize int WorkerSize int - // timeout in seconds - JobTimeout int - ShutdownTimeout int + // 0 for infinity time + JobTimeout time.Duration + ShutdownTimeout time.Duration } type workerPool struct { jobQueue chan Job workerSize int - jobTimeout int - shutdownTimeout int + jobTimeout time.Duration + shutdownTimeout time.Duration + + status Status + statusLock sync.RWMutex + + wgShutdown sync.WaitGroup workerCancellationLock sync.Mutex workerCancellationFunc []func() @@ -47,35 +53,67 @@ func NewWorkerPool(config WorkerPoolConfig) (Worker, error) { workerSize: config.WorkerSize, jobTimeout: config.JobTimeout, shutdownTimeout: config.ShutdownTimeout, + status: StatusCreated, }, nil } func (w *workerPool) Start() error { + if w.Status() != StatusCreated { + return errors.New("worker has been running or already stopped") + } w.initWorker() return nil } func (w *workerPool) Shutdown() error { + if w.Status() != StatusRunning { + return errors.New("worker is not running") + } w.stopWorker() return nil } -func (w *workerPool) GetJobTimeout() int { +func (w *workerPool) Status() Status { + w.statusLock.RLock() + defer w.statusLock.RUnlock() + return w.status +} + +func (w *workerPool) setStatus(status Status) { + w.statusLock.Lock() + defer w.statusLock.Unlock() + w.status = status +} + +func (w *workerPool) GetJobTimeout() time.Duration { return w.jobTimeout } -func (w *workerPool) GetShutdownTimeout() int { +func (w *workerPool) GetShutdownTimeout() time.Duration { return w.shutdownTimeout } func (w *workerPool) Push(job Job) error { + if w.Status() != StatusRunning { + return errors.New("worker is not running") + } + if job == nil { + return errors.New("job is nil") + } w.jobQueue <- job return nil } func (w *workerPool) PushAndWait(job Job) error { - //TODO implement me - panic("implement me") + if w.Status() != StatusRunning { + return errors.New("worker is not running") + } + if job == nil { + return errors.New("job is nil") + } + w.jobQueue <- job + <-job.Done() + return nil } func (w *workerPool) initWorker() { @@ -83,7 +121,8 @@ func (w *workerPool) initWorker() { defer w.workerCancellationLock.Unlock() for i := 0; i < w.workerSize; i++ { - cancelChan := make(chan IsCanceled) + w.wgShutdown.Add(1) + cancelChan := make(chan ChanSignal) cancelFunc := func() { close(cancelChan) } @@ -92,23 +131,43 @@ func (w *workerPool) initWorker() { w.workerCancellationFunc = append(w.workerCancellationFunc, cancelFunc) } + w.setStatus(StatusRunning) } func (w *workerPool) stopWorker() { w.workerCancellationLock.Lock() defer w.workerCancellationLock.Unlock() + w.setStatus(StatusStopped) + for _, f := range w.workerCancellationFunc { f() } w.workerCancellationFunc = nil + + close(w.jobQueue) + + w.wgShutdown.Wait() } -func (w *workerPool) jobDispatcher(cancelChan chan IsCanceled) { +func (w *workerPool) jobDispatcher(cancelChan chan ChanSignal) { + defer func() { + for job := range w.jobQueue { + ctx, cancel := context.WithTimeout(context.Background(), 0) + job.Do(ctx) + cancel() + } + w.wgShutdown.Done() + }() + for { select { case job := <-w.jobQueue: - done := make(chan struct{}) + if job == nil { + continue + } + + done := make(chan ChanSignal) go func() { var ( @@ -116,7 +175,7 @@ func (w *workerPool) jobDispatcher(cancelChan chan IsCanceled) { cancelFunc context.CancelFunc ) if w.jobTimeout > 0 { - ctx, cancelFunc = context.WithTimeout(ctx, time.Duration(w.jobTimeout)*time.Second) + ctx, cancelFunc = context.WithTimeout(ctx, w.jobTimeout) } job.Do(ctx) @@ -132,11 +191,17 @@ func (w *workerPool) jobDispatcher(cancelChan chan IsCanceled) { // the job is done case <-cancelChan: // on graceful shutdown - ctx, cancelFunc := context.WithTimeout(context.Background(), time.Duration(w.shutdownTimeout)*time.Second) - - job.Cancel(ctx) - - cancelFunc() + if w.shutdownTimeout > 0 { + select { + case <-done: + case <-time.After(w.shutdownTimeout): + ctx, cancel := context.WithTimeout(context.Background(), 0) + job.Cancel(ctx) + cancel() + } + } else { + <-done + } return } case <-cancelChan: diff --git a/worker_pool_test.go b/worker_pool_test.go new file mode 100644 index 0000000..55d24de --- /dev/null +++ b/worker_pool_test.go @@ -0,0 +1,63 @@ +package workers + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWorkerPoolMainFlow(t *testing.T) { + worker, err := NewWorkerPool(WorkerPoolConfig{}) + assert.NoError(t, err) + + value := 0 + + job := NewJobSimple(func(isCanceled chan ChanSignal, requestData interface{}, result chan *Result) { + val, ok := requestData.(*int) + assert.True(t, ok) + *val++ + }, &value, nil) + + err = worker.Push(job) + assert.Error(t, err) + + err = worker.PushAndWait(job) + assert.Error(t, err) + + err = worker.Shutdown() + assert.Error(t, err) + + assert.Equal(t, StatusCreated, worker.Status()) + + // start the worker + err = worker.Start() + assert.NoError(t, err) + + assert.Equal(t, StatusRunning, worker.Status()) + + err = worker.Start() + assert.Error(t, err) + + err = worker.Push(job) + assert.NoError(t, err) + + <-job.Done() + + err = worker.PushAndWait(job) + assert.NoError(t, err) + + // the job will not be executed twice since it's already running or done + assert.Equal(t, 1, value) + + // shutdown the worker + err = worker.Shutdown() + assert.NoError(t, err) + + assert.Equal(t, StatusStopped, worker.Status()) + + err = worker.Start() + assert.Error(t, err) + + err = worker.Shutdown() + assert.Error(t, err) +}