From 597216c47fe09af76ac908c242353df5b5b8b0c4 Mon Sep 17 00:00:00 2001 From: Brandur Date: Sun, 8 Sep 2024 18:44:34 -0700 Subject: [PATCH] Add middleware system for jobs Here, experiment with a middleware-like system that adds middleware functions to job lifecycles, which results in them being invoked during specific phases of a job like as it's being inserted or worked. The most obvious unlock for this is telemetry (e.g. logging, metrics), but it also acts as a building block for features like encrypted jobs. Co-authored-by: Blake Gentry --- client.go | 61 +++++++++++++------ client_test.go | 4 +- internal/dbunique/db_unique.go | 5 +- internal/dbunique/db_unique_test.go | 3 +- internal/maintenance/periodic_job_enqueuer.go | 6 +- .../maintenance/periodic_job_enqueuer_test.go | 10 +-- internal/maintenance/queue_maintainer_test.go | 4 +- job_executor.go | 27 ++++++-- middleware_defaults.go | 21 +++++++ middleware_defaults_test.go | 5 ++ periodic_job.go | 3 +- producer.go | 6 +- producer_test.go | 4 +- rivertype/river_type.go | 42 +++++++++++++ 14 files changed, 154 insertions(+), 47 deletions(-) create mode 100644 middleware_defaults.go create mode 100644 middleware_defaults_test.go diff --git a/client.go b/client.go index 7fd6506b..d0f6f848 100644 --- a/client.go +++ b/client.go @@ -139,6 +139,10 @@ type Config struct { // deployments. JobCleanerTimeout time.Duration + // JobMiddleware are optional functions that can be called around different + // parts of each job's lifecycle. + JobMiddleware []rivertype.JobMiddleware + // JobTimeout is the maximum amount of time a job is allowed to run before its // context is cancelled. A timeout of zero means JobTimeoutDefault will be // used, whereas a value of -1 means the job's context will not be cancelled @@ -467,6 +471,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client FetchCooldown: valutil.ValOrDefault(config.FetchCooldown, FetchCooldownDefault), FetchPollInterval: valutil.ValOrDefault(config.FetchPollInterval, FetchPollIntervalDefault), ID: valutil.ValOrDefaultFunc(config.ID, func() string { return defaultClientID(time.Now().UTC()) }), + JobMiddleware: config.JobMiddleware, JobTimeout: valutil.ValOrDefault(config.JobTimeout, JobTimeoutDefault), Logger: logger, MaxAttempts: valutil.ValOrDefault(config.MaxAttempts, MaxAttemptsDefault), @@ -1165,7 +1170,7 @@ func (c *Client[TTx]) ID() string { return c.config.ID } -func insertParamsFromConfigArgsAndOptions(archetype *baseservice.Archetype, config *Config, args JobArgs, insertOpts *InsertOpts) (*riverdriver.JobInsertFastParams, error) { +func insertParamsFromConfigArgsAndOptions(archetype *baseservice.Archetype, config *Config, args JobArgs, insertOpts *InsertOpts) (*rivertype.JobInsertParams, error) { encodedArgs, err := json.Marshal(args) if err != nil { return nil, fmt.Errorf("error marshaling args to JSON: %w", err) @@ -1227,13 +1232,13 @@ func insertParamsFromConfigArgsAndOptions(archetype *baseservice.Archetype, conf metadata = []byte("{}") } - insertParams := &riverdriver.JobInsertFastParams{ + insertParams := &rivertype.JobInsertParams{ Args: args, CreatedAt: createdAt, - EncodedArgs: json.RawMessage(encodedArgs), + EncodedArgs: encodedArgs, Kind: args.Kind(), MaxAttempts: maxAttempts, - Metadata: json.RawMessage(metadata), + Metadata: metadata, Priority: priority, Queue: queue, State: rivertype.JobStateAvailable, @@ -1436,39 +1441,56 @@ func (c *Client[TTx]) insertMany(ctx context.Context, tx riverdriver.ExecutorTx, func (c *Client[TTx]) insertManyShared( ctx context.Context, tx riverdriver.ExecutorTx, - params []InsertManyParams, + rawParams []InsertManyParams, execute func(context.Context, []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error), ) ([]*rivertype.JobInsertResult, error) { - insertParams, err := c.insertManyParams(params) + insertParams, err := c.insertManyParams(rawParams) if err != nil { return nil, err } - inserted, err := execute(ctx, insertParams) - if err != nil { - return inserted, err - } + doInner := func(ctx context.Context) ([]*rivertype.JobInsertResult, error) { + finalInsertParams := sliceutil.Map(insertParams, func(params *rivertype.JobInsertParams) *riverdriver.JobInsertFastParams { + return (*riverdriver.JobInsertFastParams)(params) + }) + results, err := execute(ctx, finalInsertParams) + if err != nil { + return results, err + } - queues := make([]string, 0, 10) - for _, params := range insertParams { - if params.State == rivertype.JobStateAvailable { - queues = append(queues, params.Queue) + queues := make([]string, 0, 10) + for _, params := range insertParams { + if params.State == rivertype.JobStateAvailable { + queues = append(queues, params.Queue) + } } + if err := c.maybeNotifyInsertForQueues(ctx, tx, queues); err != nil { + return nil, err + } + return results, nil } - if err := c.maybeNotifyInsertForQueues(ctx, tx, queues); err != nil { - return nil, err + + if len(c.config.JobMiddleware) > 0 { + // Wrap middlewares in reverse order so the one defined first is wrapped + // as the outermost function and is first to receive the operation. + for i := len(c.config.JobMiddleware) - 1; i >= 0; i-- { + doInner = func(ctx context.Context) ([]*rivertype.JobInsertResult, error) { + return c.config.JobMiddleware[i].InsertMany(ctx, insertParams, doInner) + } + } } - return inserted, nil + + return doInner(ctx) } // Validates input parameters for a batch insert operation and generates a set // of batch insert parameters. -func (c *Client[TTx]) insertManyParams(params []InsertManyParams) ([]*riverdriver.JobInsertFastParams, error) { +func (c *Client[TTx]) insertManyParams(params []InsertManyParams) ([]*rivertype.JobInsertParams, error) { if len(params) < 1 { return nil, errors.New("no jobs to insert") } - insertParams := make([]*riverdriver.JobInsertFastParams, len(params)) + insertParams := make([]*rivertype.JobInsertParams, len(params)) for i, param := range params { if err := c.validateJobArgs(param.Args); err != nil { return nil, err @@ -1665,6 +1687,7 @@ func (c *Client[TTx]) addProducer(queueName string, queueConfig QueueConfig) *pr ErrorHandler: c.config.ErrorHandler, FetchCooldown: c.config.FetchCooldown, FetchPollInterval: c.config.FetchPollInterval, + JobMiddleware: c.config.JobMiddleware, JobTimeout: c.config.JobTimeout, MaxWorkers: queueConfig.MaxWorkers, Notifier: c.notifier, diff --git a/client_test.go b/client_test.go index 8ee79f01..868a8ee7 100644 --- a/client_test.go +++ b/client_test.go @@ -2968,7 +2968,7 @@ func Test_Client_ErrorHandler(t *testing.T) { // unknown job. insertParams, err := insertParamsFromConfigArgsAndOptions(&client.baseService.Archetype, config, unregisteredJobArgs{}, nil) require.NoError(t, err) - _, err = client.driver.GetExecutor().JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{insertParams}) + _, err = client.driver.GetExecutor().JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{(*riverdriver.JobInsertFastParams)(insertParams)}) require.NoError(t, err) riversharedtest.WaitOrTimeout(t, bundle.SubscribeChan) @@ -4570,7 +4570,7 @@ func Test_Client_UnknownJobKindErrorsTheJob(t *testing.T) { insertParams, err := insertParamsFromConfigArgsAndOptions(&client.baseService.Archetype, config, unregisteredJobArgs{}, nil) require.NoError(err) - insertedResults, err := client.driver.GetExecutor().JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{insertParams}) + insertedResults, err := client.driver.GetExecutor().JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{(*riverdriver.JobInsertFastParams)(insertParams)}) require.NoError(err) insertedResult := insertedResults[0] diff --git a/internal/dbunique/db_unique.go b/internal/dbunique/db_unique.go index 60256acf..204de826 100644 --- a/internal/dbunique/db_unique.go +++ b/internal/dbunique/db_unique.go @@ -9,7 +9,6 @@ import ( "github.com/tidwall/gjson" "github.com/tidwall/sjson" - "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/rivershared/baseservice" "github.com/riverqueue/river/rivershared/util/sliceutil" "github.com/riverqueue/river/rivertype" @@ -64,7 +63,7 @@ func (o *UniqueOpts) StateBitmask() byte { return UniqueStatesToBitmask(states) } -func UniqueKey(timeGen baseservice.TimeGenerator, uniqueOpts *UniqueOpts, params *riverdriver.JobInsertFastParams) ([]byte, error) { +func UniqueKey(timeGen baseservice.TimeGenerator, uniqueOpts *UniqueOpts, params *rivertype.JobInsertParams) ([]byte, error) { uniqueKeyString, err := buildUniqueKeyString(timeGen, uniqueOpts, params) if err != nil { return nil, err @@ -76,7 +75,7 @@ func UniqueKey(timeGen baseservice.TimeGenerator, uniqueOpts *UniqueOpts, params // Builds a unique key made up of the unique options in place. The key is hashed // to become a value for `unique_key` in the fast insertion path, or hashed and // used for an advisory lock on the slow insertion path. -func buildUniqueKeyString(timeGen baseservice.TimeGenerator, uniqueOpts *UniqueOpts, params *riverdriver.JobInsertFastParams) (string, error) { +func buildUniqueKeyString(timeGen baseservice.TimeGenerator, uniqueOpts *UniqueOpts, params *rivertype.JobInsertParams) (string, error) { var sb strings.Builder if !uniqueOpts.ExcludeKind { diff --git a/internal/dbunique/db_unique_test.go b/internal/dbunique/db_unique_test.go index bfa8336e..e2ad3037 100644 --- a/internal/dbunique/db_unique_test.go +++ b/internal/dbunique/db_unique_test.go @@ -9,7 +9,6 @@ import ( "github.com/stretchr/testify/require" - "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/rivershared/riversharedtest" "github.com/riverqueue/river/rivertype" ) @@ -229,7 +228,7 @@ func TestUniqueKey(t *testing.T) { states = tt.uniqueOpts.ByState } - jobParams := &riverdriver.JobInsertFastParams{ + jobParams := &rivertype.JobInsertParams{ Args: args, CreatedAt: &now, EncodedArgs: encodedArgs, diff --git a/internal/maintenance/periodic_job_enqueuer.go b/internal/maintenance/periodic_job_enqueuer.go index fb1c435f..6ccf11c6 100644 --- a/internal/maintenance/periodic_job_enqueuer.go +++ b/internal/maintenance/periodic_job_enqueuer.go @@ -38,7 +38,7 @@ func (ts *PeriodicJobEnqueuerTestSignals) Init() { // river.PeriodicJobArgs, but needs a separate type because the enqueuer is in a // subpackage. type PeriodicJob struct { - ConstructorFunc func() (*riverdriver.JobInsertFastParams, error) + ConstructorFunc func() (*rivertype.JobInsertParams, error) RunOnStart bool ScheduleFunc func(time.Time) time.Time @@ -373,7 +373,7 @@ func (s *PeriodicJobEnqueuer) insertBatch(ctx context.Context, insertParamsMany s.TestSignals.InsertedJobs.Signal(struct{}{}) } -func (s *PeriodicJobEnqueuer) insertParamsFromConstructor(ctx context.Context, constructorFunc func() (*riverdriver.JobInsertFastParams, error), scheduledAt time.Time) (*riverdriver.JobInsertFastParams, bool) { +func (s *PeriodicJobEnqueuer) insertParamsFromConstructor(ctx context.Context, constructorFunc func() (*rivertype.JobInsertParams, error), scheduledAt time.Time) (*riverdriver.JobInsertFastParams, bool) { insertParams, err := constructorFunc() if err != nil { if errors.Is(err, ErrNoJobToInsert) { @@ -389,7 +389,7 @@ func (s *PeriodicJobEnqueuer) insertParamsFromConstructor(ctx context.Context, c insertParams.ScheduledAt = &scheduledAt } - return insertParams, true + return (*riverdriver.JobInsertFastParams)(insertParams), true } const periodicJobEnqueuerVeryLongDuration = 24 * time.Hour diff --git a/internal/maintenance/periodic_job_enqueuer_test.go b/internal/maintenance/periodic_job_enqueuer_test.go index 48677407..189ff9fc 100644 --- a/internal/maintenance/periodic_job_enqueuer_test.go +++ b/internal/maintenance/periodic_job_enqueuer_test.go @@ -40,9 +40,9 @@ func TestPeriodicJobEnqueuer(t *testing.T) { stubSvc := &riversharedtest.TimeStub{} stubSvc.StubNowUTC(time.Now().UTC()) - jobConstructorWithQueueFunc := func(name string, unique bool, queue string) func() (*riverdriver.JobInsertFastParams, error) { - return func() (*riverdriver.JobInsertFastParams, error) { - params := &riverdriver.JobInsertFastParams{ + jobConstructorWithQueueFunc := func(name string, unique bool, queue string) func() (*rivertype.JobInsertParams, error) { + return func() (*rivertype.JobInsertParams, error) { + params := &rivertype.JobInsertParams{ Args: noOpArgs{}, EncodedArgs: []byte("{}"), Kind: name, @@ -66,7 +66,7 @@ func TestPeriodicJobEnqueuer(t *testing.T) { } } - jobConstructorFunc := func(name string, unique bool) func() (*riverdriver.JobInsertFastParams, error) { + jobConstructorFunc := func(name string, unique bool) func() (*rivertype.JobInsertParams, error) { return jobConstructorWithQueueFunc(name, unique, rivercommon.QueueDefault) } @@ -256,7 +256,7 @@ func TestPeriodicJobEnqueuer(t *testing.T) { svc.AddMany([]*PeriodicJob{ // skip this insert when it returns nil: - {ScheduleFunc: periodicIntervalSchedule(time.Second), ConstructorFunc: func() (*riverdriver.JobInsertFastParams, error) { + {ScheduleFunc: periodicIntervalSchedule(time.Second), ConstructorFunc: func() (*rivertype.JobInsertParams, error) { return nil, ErrNoJobToInsert }, RunOnStart: true}, }) diff --git a/internal/maintenance/queue_maintainer_test.go b/internal/maintenance/queue_maintainer_test.go index a68e91a8..5da37def 100644 --- a/internal/maintenance/queue_maintainer_test.go +++ b/internal/maintenance/queue_maintainer_test.go @@ -10,13 +10,13 @@ import ( "github.com/riverqueue/river/internal/riverinternaltest" "github.com/riverqueue/river/internal/riverinternaltest/sharedtx" - "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/riverdriver/riverpgxv5" "github.com/riverqueue/river/rivershared/baseservice" "github.com/riverqueue/river/rivershared/riversharedtest" "github.com/riverqueue/river/rivershared/startstop" "github.com/riverqueue/river/rivershared/startstoptest" "github.com/riverqueue/river/rivershared/testsignal" + "github.com/riverqueue/river/rivertype" ) type testService struct { @@ -107,7 +107,7 @@ func TestQueueMaintainer(t *testing.T) { NewPeriodicJobEnqueuer(archetype, &PeriodicJobEnqueuerConfig{ PeriodicJobs: []*PeriodicJob{ { - ConstructorFunc: func() (*riverdriver.JobInsertFastParams, error) { + ConstructorFunc: func() (*rivertype.JobInsertParams, error) { return nil, ErrNoJobToInsert }, ScheduleFunc: cron.Every(15 * time.Minute).Next, diff --git a/job_executor.go b/job_executor.go index 61669e9d..3ca83957 100644 --- a/job_executor.go +++ b/job_executor.go @@ -128,6 +128,7 @@ type jobExecutor struct { ErrorHandler ErrorHandler InformProducerDoneFunc func(jobRow *rivertype.JobRow) JobRow *rivertype.JobRow + JobMiddleware []rivertype.JobMiddleware SchedulerInterval time.Duration WorkUnit workunit.WorkUnit @@ -190,11 +191,11 @@ func (e *jobExecutor) execute(ctx context.Context) (res *jobExecutorResult) { return &jobExecutorResult{Err: &UnknownJobKindError{Kind: e.JobRow.Kind}} } - if err := e.WorkUnit.UnmarshalJob(); err != nil { - return &jobExecutorResult{Err: err} - } + doInner := func(ctx context.Context) error { + if err := e.WorkUnit.UnmarshalJob(); err != nil { + return err + } - { jobTimeout := e.WorkUnit.Timeout() if jobTimeout == 0 { jobTimeout = e.ClientJobTimeout @@ -207,8 +208,24 @@ func (e *jobExecutor) execute(ctx context.Context) (res *jobExecutorResult) { defer cancel() } - return &jobExecutorResult{Err: e.WorkUnit.Work(ctx)} + if err := e.WorkUnit.Work(ctx); err != nil { + return err + } + + return nil } + + if len(e.JobMiddleware) > 0 { + // Wrap middlewares in reverse order so the one defined first is wrapped + // as the outermost function and is first to receive the operation. + for i := len(e.JobMiddleware) - 1; i >= 0; i-- { + doInner = func(ctx context.Context) error { + return e.JobMiddleware[i].Work(ctx, e.JobRow, doInner) + } + } + } + + return &jobExecutorResult{Err: doInner(ctx)} } func (e *jobExecutor) invokeErrorHandler(ctx context.Context, res *jobExecutorResult) bool { diff --git a/middleware_defaults.go b/middleware_defaults.go new file mode 100644 index 00000000..c0a60a54 --- /dev/null +++ b/middleware_defaults.go @@ -0,0 +1,21 @@ +package river + +import ( + "context" + + "github.com/riverqueue/river/rivertype" +) + +// JobMiddlewareDefaults is an embeddable struct that provides default +// implementations for the rivertype.JobMiddleware. Use of this struct is +// recommended in case rivertype.JobMiddleware is expanded in the future so that +// existing code isn't unexpectedly broken during an upgrade. +type JobMiddlewareDefaults struct{} + +func (l *JobMiddlewareDefaults) InsertMany(ctx context.Context, manyParams []*rivertype.JobInsertParams, doInner func(ctx context.Context) ([]*rivertype.JobInsertResult, error)) ([]*rivertype.JobInsertResult, error) { + return doInner(ctx) +} + +func (l *JobMiddlewareDefaults) Work(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error { + return doInner(ctx) +} diff --git a/middleware_defaults_test.go b/middleware_defaults_test.go new file mode 100644 index 00000000..323e0415 --- /dev/null +++ b/middleware_defaults_test.go @@ -0,0 +1,5 @@ +package river + +import "github.com/riverqueue/river/rivertype" + +var _ rivertype.JobMiddleware = &JobMiddlewareDefaults{} diff --git a/periodic_job.go b/periodic_job.go index c3db82f2..55d12f4e 100644 --- a/periodic_job.go +++ b/periodic_job.go @@ -4,7 +4,6 @@ import ( "time" "github.com/riverqueue/river/internal/maintenance" - "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/rivershared/util/sliceutil" "github.com/riverqueue/river/rivertype" ) @@ -180,7 +179,7 @@ func (b *PeriodicJobBundle) toInternal(periodicJob *PeriodicJob) *maintenance.Pe opts = periodicJob.opts } return &maintenance.PeriodicJob{ - ConstructorFunc: func() (*riverdriver.JobInsertFastParams, error) { + ConstructorFunc: func() (*rivertype.JobInsertParams, error) { args, options := periodicJob.constructorFunc() if args == nil { return nil, maintenance.ErrNoJobToInsert diff --git a/producer.go b/producer.go index 0a9d8e53..8d0d0e6c 100644 --- a/producer.go +++ b/producer.go @@ -63,8 +63,9 @@ type producerConfig struct { // LISTEN/NOTIFY, but this provides a fallback. FetchPollInterval time.Duration - JobTimeout time.Duration - MaxWorkers int + JobMiddleware []rivertype.JobMiddleware + JobTimeout time.Duration + MaxWorkers int // Notifier is a notifier for subscribing to new job inserts and job // control. If nil, the producer will operate in poll-only mode. @@ -579,6 +580,7 @@ func (p *producer) startNewExecutors(workCtx context.Context, jobs []*rivertype. Completer: p.completer, ErrorHandler: p.errorHandler, InformProducerDoneFunc: p.handleWorkerDone, + JobMiddleware: p.config.JobMiddleware, JobRow: job, SchedulerInterval: p.config.SchedulerInterval, WorkUnit: workUnit, diff --git a/producer_test.go b/producer_test.go index 7c79cb00..ad7f9d8d 100644 --- a/producer_test.go +++ b/producer_test.go @@ -105,7 +105,7 @@ func Test_Producer_CanSafelyCompleteJobsWhileFetchingNewOnes(t *testing.T) { insertParams, err := insertParamsFromConfigArgsAndOptions(archetype, config, WithJobNumArgs{JobNum: i}, nil) require.NoError(err) - params[i] = insertParams + params[i] = (*riverdriver.JobInsertFastParams)(insertParams) } ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) @@ -295,7 +295,7 @@ func testProducer(t *testing.T, makeProducer func(ctx context.Context, t *testin insertParams.ScheduledAt = &bundle.timeBeforeStart } - _, err = bundle.exec.JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{insertParams}) + _, err = bundle.exec.JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{(*riverdriver.JobInsertFastParams)(insertParams)}) require.NoError(t, err) } diff --git a/rivertype/river_type.go b/rivertype/river_type.go index 1998577a..8d545ebb 100644 --- a/rivertype/river_type.go +++ b/rivertype/river_type.go @@ -4,6 +4,7 @@ package rivertype import ( + "context" "errors" "time" ) @@ -229,6 +230,47 @@ type AttemptError struct { Trace string `json:"trace"` } +type JobInsertParams struct { + Args JobArgs + CreatedAt *time.Time + EncodedArgs []byte + Kind string + MaxAttempts int + Metadata []byte + Priority int + Queue string + ScheduledAt *time.Time + State JobState + Tags []string + UniqueKey []byte + UniqueStates byte +} + +// JobMiddleware provides an interface for middleware that integrations can use +// to encapsulate common logic around various phases of a job's lifecycle. +// +// Implementations should embed river.JobMiddlewareDefaults to inherit default +// implementations for phases where no custom code is needed, and for forward +// compatibility in case new functions are added to this interface. +type JobMiddleware interface { + // InsertMany is invoked around a batch insert operation. Implementations + // must always include a call to doInner to call down the middleware stack + // and perfom the batch insertion, and may run custom code before and after. + // + // Returning an error from this function will fail the overarching insert + // operation, even if the inner insertion originally succeeded. + InsertMany(ctx context.Context, manyParams []*JobInsertParams, doInner func(ctx context.Context) ([]*JobInsertResult, error)) ([]*JobInsertResult, error) + + // Work is invoked around a job's JSON args being unmarshaled and the job + // worked. Implementations must always include a call to doInner to call + // down the middleware stack and perfom the batch insertion, and may run + // custom code before and after. + // + // Returning an error from this function will fail the overarching work + // operation, even if the inner work originally succeeded. + Work(ctx context.Context, job *JobRow, doInner func(ctx context.Context) error) error +} + // PeriodicJobHandle is a reference to a dynamically added periodic job // (returned by the use of `Client.PeriodicJobs().Add()`) which can be used to // subsequently remove the periodic job with `Remove()`.