diff --git a/CHANGELOG.md b/CHANGELOG.md index e5117c55..4d9f35d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ⚠️ Version 0.13.0 removes the original advisory lock based unique jobs implementation that was deprecated in v0.12.0. See details in the note below or the v0.12.0 release notes. +### Added + +- A middleware system was added for job insertion and execution, providing the ability to extract shared functionality across workers. Both `JobInsertMiddleware` and `WorkerMiddleware` can be configured globally on the `Client`, and `WorkerMiddleware` can also be added on a per-worker basis using the new `Middleware` method on `Worker[T]`. Middleware can be useful for logging, telemetry, or for building higher level abstractions on top of base River functionality. + + Despite the interface expansion, users should not encounter any breakage if they're embedding the `WorkerDefaults` type in their workers as recommended. [PR #632](https://github.com/riverqueue/river/pull/632). + ### Changed - **Breaking change:** The advisory lock unique jobs implementation which was deprecated in v0.12.0 has been removed. Users of that feature should first upgrade to v0.12.1 to ensure they don't see any warning logs about using the deprecated advisory lock uniqueness. The new, faster unique implementation will be used automatically as long as the `UniqueOpts.ByState` list hasn't been customized to remove [required states](https://riverqueue.com/docs/unique-jobs#unique-by-state) (`pending`, `scheduled`, `available`, and `running`). As of this release, customizing `ByState` without these required states returns an error. [PR #614](https://github.com/riverqueue/river/pull/614). diff --git a/client.go b/client.go index 7fd6506b..6e4e3a0c 100644 --- a/client.go +++ b/client.go @@ -139,6 +139,10 @@ type Config struct { // deployments. 
JobCleanerTimeout time.Duration + // JobInsertMiddleware are optional functions that can be called around job + // insertion. + JobInsertMiddleware []rivertype.JobInsertMiddleware + // JobTimeout is the maximum amount of time a job is allowed to run before its // context is cancelled. A timeout of zero means JobTimeoutDefault will be // used, whereas a value of -1 means the job's context will not be cancelled @@ -235,6 +239,10 @@ type Config struct { // (i.e. That it wasn't forgotten by accident.) Workers *Workers + // WorkerMiddleware are optional functions that can be called around + // all job executions. + WorkerMiddleware []rivertype.WorkerMiddleware + // Scheduler run interval. Shared between the scheduler and producer/job // executors, but not currently exposed for configuration. schedulerInterval time.Duration @@ -467,6 +475,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client FetchCooldown: valutil.ValOrDefault(config.FetchCooldown, FetchCooldownDefault), FetchPollInterval: valutil.ValOrDefault(config.FetchPollInterval, FetchPollIntervalDefault), ID: valutil.ValOrDefaultFunc(config.ID, func() string { return defaultClientID(time.Now().UTC()) }), + JobInsertMiddleware: config.JobInsertMiddleware, JobTimeout: valutil.ValOrDefault(config.JobTimeout, JobTimeoutDefault), Logger: logger, MaxAttempts: valutil.ValOrDefault(config.MaxAttempts, MaxAttemptsDefault), @@ -478,6 +487,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client RetryPolicy: retryPolicy, TestOnly: config.TestOnly, Workers: config.Workers, + WorkerMiddleware: config.WorkerMiddleware, schedulerInterval: valutil.ValOrDefault(config.schedulerInterval, maintenance.JobSchedulerIntervalDefault), time: config.time, } @@ -1165,7 +1175,7 @@ func (c *Client[TTx]) ID() string { return c.config.ID } -func insertParamsFromConfigArgsAndOptions(archetype *baseservice.Archetype, config *Config, args JobArgs, insertOpts *InsertOpts) 
(*riverdriver.JobInsertFastParams, error) { +func insertParamsFromConfigArgsAndOptions(archetype *baseservice.Archetype, config *Config, args JobArgs, insertOpts *InsertOpts) (*rivertype.JobInsertParams, error) { encodedArgs, err := json.Marshal(args) if err != nil { return nil, fmt.Errorf("error marshaling args to JSON: %w", err) @@ -1227,13 +1237,13 @@ func insertParamsFromConfigArgsAndOptions(archetype *baseservice.Archetype, conf metadata = []byte("{}") } - insertParams := &riverdriver.JobInsertFastParams{ + insertParams := &rivertype.JobInsertParams{ Args: args, CreatedAt: createdAt, - EncodedArgs: json.RawMessage(encodedArgs), + EncodedArgs: encodedArgs, Kind: args.Kind(), MaxAttempts: maxAttempts, - Metadata: json.RawMessage(metadata), + Metadata: metadata, Priority: priority, Queue: queue, State: rivertype.JobStateAvailable, @@ -1436,39 +1446,58 @@ func (c *Client[TTx]) insertMany(ctx context.Context, tx riverdriver.ExecutorTx, func (c *Client[TTx]) insertManyShared( ctx context.Context, tx riverdriver.ExecutorTx, - params []InsertManyParams, + rawParams []InsertManyParams, execute func(context.Context, []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error), ) ([]*rivertype.JobInsertResult, error) { - insertParams, err := c.insertManyParams(params) + insertParams, err := c.insertManyParams(rawParams) if err != nil { return nil, err } - inserted, err := execute(ctx, insertParams) - if err != nil { - return inserted, err - } + doInner := func(ctx context.Context) ([]*rivertype.JobInsertResult, error) { + finalInsertParams := sliceutil.Map(insertParams, func(params *rivertype.JobInsertParams) *riverdriver.JobInsertFastParams { + return (*riverdriver.JobInsertFastParams)(params) + }) + results, err := execute(ctx, finalInsertParams) + if err != nil { + return results, err + } - queues := make([]string, 0, 10) - for _, params := range insertParams { - if params.State == rivertype.JobStateAvailable { - queues = append(queues, params.Queue) + 
queues := make([]string, 0, 10) + for _, params := range insertParams { + if params.State == rivertype.JobStateAvailable { + queues = append(queues, params.Queue) + } } + if err := c.maybeNotifyInsertForQueues(ctx, tx, queues); err != nil { + return nil, err + } + return results, nil } - if err := c.maybeNotifyInsertForQueues(ctx, tx, queues); err != nil { - return nil, err + + if len(c.config.JobInsertMiddleware) > 0 { + // Wrap middlewares in reverse order so the one defined first is wrapped + // as the outermost function and is first to receive the operation. + for i := len(c.config.JobInsertMiddleware) - 1; i >= 0; i-- { + middlewareItem := c.config.JobInsertMiddleware[i] // capture the current middleware item + previousDoInner := doInner // Capture the current doInner function + doInner = func(ctx context.Context) ([]*rivertype.JobInsertResult, error) { + return middlewareItem.InsertMany(ctx, insertParams, previousDoInner) + } + } } - return inserted, nil + + return doInner(ctx) } // Validates input parameters for a batch insert operation and generates a set // of batch insert parameters. 
-func (c *Client[TTx]) insertManyParams(params []InsertManyParams) ([]*riverdriver.JobInsertFastParams, error) { +func (c *Client[TTx]) insertManyParams(params []InsertManyParams) ([]*rivertype.JobInsertParams, error) { if len(params) < 1 { return nil, errors.New("no jobs to insert") } - insertParams := make([]*riverdriver.JobInsertFastParams, len(params)) + insertParams := make([]*rivertype.JobInsertParams, len(params)) for i, param := range params { if err := c.validateJobArgs(param.Args); err != nil { return nil, err @@ -1665,6 +1694,7 @@ func (c *Client[TTx]) addProducer(queueName string, queueConfig QueueConfig) *pr ErrorHandler: c.config.ErrorHandler, FetchCooldown: c.config.FetchCooldown, FetchPollInterval: c.config.FetchPollInterval, + GlobalMiddleware: c.config.WorkerMiddleware, JobTimeout: c.config.JobTimeout, MaxWorkers: queueConfig.MaxWorkers, Notifier: c.notifier, diff --git a/client_test.go b/client_test.go index cd9a9908..84207bed 100644 --- a/client_test.go +++ b/client_test.go @@ -20,6 +20,7 @@ import ( "github.com/jackc/pgx/v5/stdlib" "github.com/robfig/cron/v3" "github.com/stretchr/testify/require" + "github.com/tidwall/sjson" "github.com/riverqueue/river/internal/dbunique" "github.com/riverqueue/river/internal/maintenance" @@ -589,6 +590,90 @@ func Test_Client(t *testing.T) { require.Equal(t, `relation "river_job" does not exist`, pgErr.Message) }) + t.Run("WithWorkerMiddleware", func(t *testing.T) { + t.Parallel() + + _, bundle := setup(t) + middlewareCalled := false + + type privateKey string + + middleware := &overridableJobMiddleware{ + workFunc: func(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error { + ctx = context.WithValue(ctx, privateKey("middleware"), "called") + middlewareCalled = true + return doInner(ctx) + }, + } + bundle.config.WorkerMiddleware = []rivertype.WorkerMiddleware{middleware} + + AddWorker(bundle.config.Workers, WorkFunc(func(ctx context.Context, job *Job[callbackArgs]) error { 
+ require.Equal(t, "called", ctx.Value(privateKey("middleware"))) + return nil + })) + + driver := riverpgxv5.New(bundle.dbPool) + client, err := NewClient(driver, bundle.config) + require.NoError(t, err) + + subscribeChan := subscribe(t, client) + startClient(ctx, t, client) + + result, err := client.Insert(ctx, callbackArgs{}, nil) + require.NoError(t, err) + + event := riversharedtest.WaitOrTimeout(t, subscribeChan) + require.Equal(t, EventKindJobCompleted, event.Kind) + require.Equal(t, result.Job.ID, event.Job.ID) + require.True(t, middlewareCalled) + }) + + t.Run("WithWorkerMiddlewareOnWorker", func(t *testing.T) { + t.Parallel() + + _, bundle := setup(t) + middlewareCalled := false + + type privateKey string + + worker := &workerWithMiddleware[callbackArgs]{ + workFunc: func(ctx context.Context, job *Job[callbackArgs]) error { + require.Equal(t, "called", ctx.Value(privateKey("middleware"))) + return nil + }, + middlewareFunc: func(job *Job[callbackArgs]) []rivertype.WorkerMiddleware { + require.Equal(t, "middleware_test", job.Args.Name, "JSON should be decoded before middleware is called") + + return []rivertype.WorkerMiddleware{ + &overridableJobMiddleware{ + workFunc: func(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error { + ctx = context.WithValue(ctx, privateKey("middleware"), "called") + middlewareCalled = true + return doInner(ctx) + }, + }, + } + }, + } + + AddWorker(bundle.config.Workers, worker) + + driver := riverpgxv5.New(bundle.dbPool) + client, err := NewClient(driver, bundle.config) + require.NoError(t, err) + + subscribeChan := subscribe(t, client) + startClient(ctx, t, client) + + result, err := client.Insert(ctx, callbackArgs{Name: "middleware_test"}, nil) + require.NoError(t, err) + + event := riversharedtest.WaitOrTimeout(t, subscribeChan) + require.Equal(t, EventKindJobCompleted, event.Kind) + require.Equal(t, result.Job.ID, event.Job.ID) + require.True(t, middlewareCalled) + }) + 
t.Run("PauseAndResumeSingleQueue", func(t *testing.T) { t.Parallel() @@ -835,6 +920,20 @@ func Test_Client(t *testing.T) { }) } +type workerWithMiddleware[T JobArgs] struct { + WorkerDefaults[T] + workFunc func(context.Context, *Job[T]) error + middlewareFunc func(*Job[T]) []rivertype.WorkerMiddleware +} + +func (w *workerWithMiddleware[T]) Work(ctx context.Context, job *Job[T]) error { + return w.workFunc(ctx, job) +} + +func (w *workerWithMiddleware[T]) Middleware(job *Job[T]) []rivertype.WorkerMiddleware { + return w.middlewareFunc(job) +} + func Test_Client_Stop(t *testing.T) { t.Parallel() @@ -2420,6 +2519,48 @@ func Test_Client_InsertManyTx(t *testing.T) { require.Len(t, results, 1) }) + t.Run("WithJobInsertMiddleware", func(t *testing.T) { + t.Parallel() + + _, bundle := setup(t) + config := newTestConfig(t, nil) + config.Queues = nil + + insertCalled := false + var innerResults []*rivertype.JobInsertResult + + middleware := &overridableJobMiddleware{ + insertManyFunc: func(ctx context.Context, manyParams []*rivertype.JobInsertParams, doInner func(ctx context.Context) ([]*rivertype.JobInsertResult, error)) ([]*rivertype.JobInsertResult, error) { + insertCalled = true + var err error + for _, params := range manyParams { + params.Metadata, err = sjson.SetBytes(params.Metadata, "middleware", "called") + require.NoError(t, err) + } + + results, err := doInner(ctx) + require.NoError(t, err) + innerResults = results + return results, nil + }, + } + + config.JobInsertMiddleware = []rivertype.JobInsertMiddleware{middleware} + driver := riverpgxv5.New(nil) + client, err := NewClient(driver, config) + require.NoError(t, err) + + results, err := client.InsertManyTx(ctx, bundle.tx, []InsertManyParams{{Args: noOpArgs{}}}) + require.NoError(t, err) + require.Len(t, results, 1) + + require.True(t, insertCalled) + require.Len(t, innerResults, 1) + require.Len(t, results, 1) + require.Equal(t, innerResults[0].Job.ID, results[0].Job.ID) + require.JSONEq(t, `{"middleware": 
"called"}`, string(results[0].Job.Metadata)) + }) + t.Run("WithUniqueOpts", func(t *testing.T) { t.Parallel() @@ -2998,7 +3139,7 @@ func Test_Client_ErrorHandler(t *testing.T) { // unknown job. insertParams, err := insertParamsFromConfigArgsAndOptions(&client.baseService.Archetype, config, unregisteredJobArgs{}, nil) require.NoError(t, err) - _, err = client.driver.GetExecutor().JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{insertParams}) + _, err = client.driver.GetExecutor().JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{(*riverdriver.JobInsertFastParams)(insertParams)}) require.NoError(t, err) riversharedtest.WaitOrTimeout(t, bundle.SubscribeChan) @@ -4600,7 +4741,7 @@ func Test_Client_UnknownJobKindErrorsTheJob(t *testing.T) { insertParams, err := insertParamsFromConfigArgsAndOptions(&client.baseService.Archetype, config, unregisteredJobArgs{}, nil) require.NoError(err) - insertedResults, err := client.driver.GetExecutor().JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{insertParams}) + insertedResults, err := client.driver.GetExecutor().JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{(*riverdriver.JobInsertFastParams)(insertParams)}) require.NoError(err) insertedResult := insertedResults[0] @@ -4770,6 +4911,14 @@ func Test_NewClient_Overrides(t *testing.T) { retryPolicy := &DefaultClientRetryPolicy{} + type noOpInsertMiddleware struct { + JobInsertMiddlewareDefaults + } + + type noOpWorkerMiddleware struct { + WorkerMiddlewareDefaults + } + client, err := NewClient(riverpgxv5.New(dbPool), &Config{ AdvisoryLockPrefix: 123_456, CancelledJobRetentionPeriod: 1 * time.Hour, @@ -4778,6 +4927,7 @@ func Test_NewClient_Overrides(t *testing.T) { ErrorHandler: errorHandler, FetchCooldown: 123 * time.Millisecond, FetchPollInterval: 124 * time.Millisecond, + JobInsertMiddleware: []rivertype.JobInsertMiddleware{&noOpInsertMiddleware{}}, JobTimeout: 125 * time.Millisecond, Logger: logger, MaxAttempts: 5, @@ -4785,6 +4935,7 @@ func 
Test_NewClient_Overrides(t *testing.T) { RetryPolicy: retryPolicy, TestOnly: true, // disables staggered start in maintenance services Workers: workers, + WorkerMiddleware: []rivertype.WorkerMiddleware{&noOpWorkerMiddleware{}}, }) require.NoError(t, err) @@ -4803,10 +4954,12 @@ func Test_NewClient_Overrides(t *testing.T) { require.Equal(t, errorHandler, client.config.ErrorHandler) require.Equal(t, 123*time.Millisecond, client.config.FetchCooldown) require.Equal(t, 124*time.Millisecond, client.config.FetchPollInterval) + require.Len(t, client.config.JobInsertMiddleware, 1) require.Equal(t, 125*time.Millisecond, client.config.JobTimeout) require.Equal(t, logger, client.baseService.Logger) require.Equal(t, 5, client.config.MaxAttempts) require.Equal(t, retryPolicy, client.config.RetryPolicy) + require.Len(t, client.config.WorkerMiddleware, 1) } func Test_NewClient_MissingParameters(t *testing.T) { diff --git a/internal/dbunique/db_unique.go b/internal/dbunique/db_unique.go index 60256acf..e4ccd010 100644 --- a/internal/dbunique/db_unique.go +++ b/internal/dbunique/db_unique.go @@ -9,7 +9,6 @@ import ( "github.com/tidwall/gjson" "github.com/tidwall/sjson" - "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/rivershared/baseservice" "github.com/riverqueue/river/rivershared/util/sliceutil" "github.com/riverqueue/river/rivertype" @@ -64,7 +63,7 @@ func (o *UniqueOpts) StateBitmask() byte { return UniqueStatesToBitmask(states) } -func UniqueKey(timeGen baseservice.TimeGenerator, uniqueOpts *UniqueOpts, params *riverdriver.JobInsertFastParams) ([]byte, error) { +func UniqueKey(timeGen baseservice.TimeGenerator, uniqueOpts *UniqueOpts, params *rivertype.JobInsertParams) ([]byte, error) { uniqueKeyString, err := buildUniqueKeyString(timeGen, uniqueOpts, params) if err != nil { return nil, err @@ -74,9 +73,8 @@ func UniqueKey(timeGen baseservice.TimeGenerator, uniqueOpts *UniqueOpts, params } // Builds a unique key made up of the unique options in place. 
The key is hashed -// to become a value for `unique_key` in the fast insertion path, or hashed and -// used for an advisory lock on the slow insertion path. -func buildUniqueKeyString(timeGen baseservice.TimeGenerator, uniqueOpts *UniqueOpts, params *riverdriver.JobInsertFastParams) (string, error) { +// to become a value for `unique_key`. +func buildUniqueKeyString(timeGen baseservice.TimeGenerator, uniqueOpts *UniqueOpts, params *rivertype.JobInsertParams) (string, error) { var sb strings.Builder if !uniqueOpts.ExcludeKind { diff --git a/internal/dbunique/db_unique_test.go b/internal/dbunique/db_unique_test.go index bfa8336e..e2ad3037 100644 --- a/internal/dbunique/db_unique_test.go +++ b/internal/dbunique/db_unique_test.go @@ -9,7 +9,6 @@ import ( "github.com/stretchr/testify/require" - "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/rivershared/riversharedtest" "github.com/riverqueue/river/rivertype" ) @@ -229,7 +228,7 @@ func TestUniqueKey(t *testing.T) { states = tt.uniqueOpts.ByState } - jobParams := &riverdriver.JobInsertFastParams{ + jobParams := &rivertype.JobInsertParams{ Args: args, CreatedAt: &now, EncodedArgs: encodedArgs, diff --git a/internal/maintenance/job_rescuer_test.go b/internal/maintenance/job_rescuer_test.go index ec938875..f23941af 100644 --- a/internal/maintenance/job_rescuer_test.go +++ b/internal/maintenance/job_rescuer_test.go @@ -38,10 +38,11 @@ type callbackWorkUnit struct { timeout time.Duration // defaults to 0, which signals default timeout } -func (w *callbackWorkUnit) NextRetry() time.Time { return time.Now().Add(30 * time.Second) } -func (w *callbackWorkUnit) Timeout() time.Duration { return w.timeout } -func (w *callbackWorkUnit) Work(ctx context.Context) error { return w.callback(ctx, w.jobRow) } -func (w *callbackWorkUnit) UnmarshalJob() error { return nil } +func (w *callbackWorkUnit) Middleware() []rivertype.WorkerMiddleware { return nil } +func (w *callbackWorkUnit) NextRetry() time.Time { return 
time.Now().Add(30 * time.Second) } +func (w *callbackWorkUnit) Timeout() time.Duration { return w.timeout } +func (w *callbackWorkUnit) Work(ctx context.Context) error { return w.callback(ctx, w.jobRow) } +func (w *callbackWorkUnit) UnmarshalJob() error { return nil } type SimpleClientRetryPolicy struct{} diff --git a/internal/maintenance/periodic_job_enqueuer.go b/internal/maintenance/periodic_job_enqueuer.go index fb1c435f..6ccf11c6 100644 --- a/internal/maintenance/periodic_job_enqueuer.go +++ b/internal/maintenance/periodic_job_enqueuer.go @@ -38,7 +38,7 @@ func (ts *PeriodicJobEnqueuerTestSignals) Init() { // river.PeriodicJobArgs, but needs a separate type because the enqueuer is in a // subpackage. type PeriodicJob struct { - ConstructorFunc func() (*riverdriver.JobInsertFastParams, error) + ConstructorFunc func() (*rivertype.JobInsertParams, error) RunOnStart bool ScheduleFunc func(time.Time) time.Time @@ -373,7 +373,7 @@ func (s *PeriodicJobEnqueuer) insertBatch(ctx context.Context, insertParamsMany s.TestSignals.InsertedJobs.Signal(struct{}{}) } -func (s *PeriodicJobEnqueuer) insertParamsFromConstructor(ctx context.Context, constructorFunc func() (*riverdriver.JobInsertFastParams, error), scheduledAt time.Time) (*riverdriver.JobInsertFastParams, bool) { +func (s *PeriodicJobEnqueuer) insertParamsFromConstructor(ctx context.Context, constructorFunc func() (*rivertype.JobInsertParams, error), scheduledAt time.Time) (*riverdriver.JobInsertFastParams, bool) { insertParams, err := constructorFunc() if err != nil { if errors.Is(err, ErrNoJobToInsert) { @@ -389,7 +389,7 @@ func (s *PeriodicJobEnqueuer) insertParamsFromConstructor(ctx context.Context, c insertParams.ScheduledAt = &scheduledAt } - return insertParams, true + return (*riverdriver.JobInsertFastParams)(insertParams), true } const periodicJobEnqueuerVeryLongDuration = 24 * time.Hour diff --git a/internal/maintenance/periodic_job_enqueuer_test.go b/internal/maintenance/periodic_job_enqueuer_test.go 
index 48677407..189ff9fc 100644 --- a/internal/maintenance/periodic_job_enqueuer_test.go +++ b/internal/maintenance/periodic_job_enqueuer_test.go @@ -40,9 +40,9 @@ func TestPeriodicJobEnqueuer(t *testing.T) { stubSvc := &riversharedtest.TimeStub{} stubSvc.StubNowUTC(time.Now().UTC()) - jobConstructorWithQueueFunc := func(name string, unique bool, queue string) func() (*riverdriver.JobInsertFastParams, error) { - return func() (*riverdriver.JobInsertFastParams, error) { - params := &riverdriver.JobInsertFastParams{ + jobConstructorWithQueueFunc := func(name string, unique bool, queue string) func() (*rivertype.JobInsertParams, error) { + return func() (*rivertype.JobInsertParams, error) { + params := &rivertype.JobInsertParams{ Args: noOpArgs{}, EncodedArgs: []byte("{}"), Kind: name, @@ -66,7 +66,7 @@ func TestPeriodicJobEnqueuer(t *testing.T) { } } - jobConstructorFunc := func(name string, unique bool) func() (*riverdriver.JobInsertFastParams, error) { + jobConstructorFunc := func(name string, unique bool) func() (*rivertype.JobInsertParams, error) { return jobConstructorWithQueueFunc(name, unique, rivercommon.QueueDefault) } @@ -256,7 +256,7 @@ func TestPeriodicJobEnqueuer(t *testing.T) { svc.AddMany([]*PeriodicJob{ // skip this insert when it returns nil: - {ScheduleFunc: periodicIntervalSchedule(time.Second), ConstructorFunc: func() (*riverdriver.JobInsertFastParams, error) { + {ScheduleFunc: periodicIntervalSchedule(time.Second), ConstructorFunc: func() (*rivertype.JobInsertParams, error) { return nil, ErrNoJobToInsert }, RunOnStart: true}, }) diff --git a/internal/maintenance/queue_maintainer_test.go b/internal/maintenance/queue_maintainer_test.go index a68e91a8..5da37def 100644 --- a/internal/maintenance/queue_maintainer_test.go +++ b/internal/maintenance/queue_maintainer_test.go @@ -10,13 +10,13 @@ import ( "github.com/riverqueue/river/internal/riverinternaltest" "github.com/riverqueue/river/internal/riverinternaltest/sharedtx" - 
"github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/riverdriver/riverpgxv5" "github.com/riverqueue/river/rivershared/baseservice" "github.com/riverqueue/river/rivershared/riversharedtest" "github.com/riverqueue/river/rivershared/startstop" "github.com/riverqueue/river/rivershared/startstoptest" "github.com/riverqueue/river/rivershared/testsignal" + "github.com/riverqueue/river/rivertype" ) type testService struct { @@ -107,7 +107,7 @@ func TestQueueMaintainer(t *testing.T) { NewPeriodicJobEnqueuer(archetype, &PeriodicJobEnqueuerConfig{ PeriodicJobs: []*PeriodicJob{ { - ConstructorFunc: func() (*riverdriver.JobInsertFastParams, error) { + ConstructorFunc: func() (*rivertype.JobInsertParams, error) { return nil, ErrNoJobToInsert }, ScheduleFunc: cron.Every(15 * time.Minute).Next, diff --git a/internal/workunit/work_unit.go b/internal/workunit/work_unit.go index 746454f2..fc1f541d 100644 --- a/internal/workunit/work_unit.go +++ b/internal/workunit/work_unit.go @@ -15,6 +15,7 @@ import ( // // Implemented by river.wrapperWorkUnit. 
type WorkUnit interface { + Middleware() []rivertype.WorkerMiddleware NextRetry() time.Time Timeout() time.Duration UnmarshalJob() error diff --git a/job_executor.go b/job_executor.go index e1ce5304..ce730edc 100644 --- a/job_executor.go +++ b/job_executor.go @@ -131,6 +131,7 @@ type jobExecutor struct { ErrorHandler ErrorHandler InformProducerDoneFunc func(jobRow *rivertype.JobRow) JobRow *rivertype.JobRow + GlobalMiddleware []rivertype.WorkerMiddleware SchedulerInterval time.Duration WorkUnit workunit.WorkUnit @@ -197,7 +198,9 @@ func (e *jobExecutor) execute(ctx context.Context) (res *jobExecutorResult) { return &jobExecutorResult{Err: err} } - { + workerMiddleware := e.WorkUnit.Middleware() + + doInner := func(ctx context.Context) error { jobTimeout := e.WorkUnit.Timeout() if jobTimeout == 0 { jobTimeout = e.ClientJobTimeout @@ -210,8 +213,30 @@ func (e *jobExecutor) execute(ctx context.Context) (res *jobExecutorResult) { defer cancel() } - return &jobExecutorResult{Err: e.WorkUnit.Work(ctx)} + if err := e.WorkUnit.Work(ctx); err != nil { + return err + } + + return nil + } + + allMiddleware := make([]rivertype.WorkerMiddleware, 0, len(e.GlobalMiddleware)+len(workerMiddleware)) + allMiddleware = append(allMiddleware, e.GlobalMiddleware...) + allMiddleware = append(allMiddleware, workerMiddleware...) + + if len(allMiddleware) > 0 { + // Wrap middlewares in reverse order so the one defined first is wrapped + // as the outermost function and is first to receive the operation. 
+ for i := len(allMiddleware) - 1; i >= 0; i-- { + middlewareItem := allMiddleware[i] // capture the current middleware item + previousDoInner := doInner // Capture the current doInner function + doInner = func(ctx context.Context) error { + return middlewareItem.Work(ctx, e.JobRow, previousDoInner) + } + } } + + return &jobExecutorResult{Err: doInner(ctx)} } func (e *jobExecutor) invokeErrorHandler(ctx context.Context, res *jobExecutorResult) bool { diff --git a/middleware_defaults.go b/middleware_defaults.go new file mode 100644 index 00000000..3a9195cb --- /dev/null +++ b/middleware_defaults.go @@ -0,0 +1,23 @@ +package river + +import ( + "context" + + "github.com/riverqueue/river/rivertype" +) + +// JobInsertMiddlewareDefaults is an embeddable struct that provides default +// implementations for the rivertype.JobInsertMiddleware. Use of this struct is +// recommended in case rivertype.JobInsertMiddleware is expanded in the future so that +// existing code isn't unexpectedly broken during an upgrade. 
+type JobInsertMiddlewareDefaults struct{} + +func (d *JobInsertMiddlewareDefaults) InsertMany(ctx context.Context, manyParams []*rivertype.JobInsertParams, doInner func(ctx context.Context) ([]*rivertype.JobInsertResult, error)) ([]*rivertype.JobInsertResult, error) { + return doInner(ctx) +} + +type WorkerMiddlewareDefaults struct{} + +func (d *WorkerMiddlewareDefaults) Work(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error { + return doInner(ctx) +} diff --git a/middleware_defaults_test.go b/middleware_defaults_test.go new file mode 100644 index 00000000..f1520dac --- /dev/null +++ b/middleware_defaults_test.go @@ -0,0 +1,8 @@ +package river + +import "github.com/riverqueue/river/rivertype" + +var ( + _ rivertype.JobInsertMiddleware = &JobInsertMiddlewareDefaults{} + _ rivertype.WorkerMiddleware = &WorkerMiddlewareDefaults{} +) diff --git a/middleware_test.go b/middleware_test.go new file mode 100644 index 00000000..4e9b17dc --- /dev/null +++ b/middleware_test.go @@ -0,0 +1,29 @@ +package river + +import ( + "context" + + "github.com/riverqueue/river/rivertype" +) + +type overridableJobMiddleware struct { + JobInsertMiddlewareDefaults + WorkerMiddlewareDefaults + + insertManyFunc func(ctx context.Context, manyParams []*rivertype.JobInsertParams, doInner func(ctx context.Context) ([]*rivertype.JobInsertResult, error)) ([]*rivertype.JobInsertResult, error) + workFunc func(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error +} + +func (m *overridableJobMiddleware) InsertMany(ctx context.Context, manyParams []*rivertype.JobInsertParams, doInner func(ctx context.Context) ([]*rivertype.JobInsertResult, error)) ([]*rivertype.JobInsertResult, error) { + if m.insertManyFunc != nil { + return m.insertManyFunc(ctx, manyParams, doInner) + } + return m.JobInsertMiddlewareDefaults.InsertMany(ctx, manyParams, doInner) +} + +func (m *overridableJobMiddleware) Work(ctx context.Context, job 
*rivertype.JobRow, doInner func(ctx context.Context) error) error { + if m.workFunc != nil { + return m.workFunc(ctx, job, doInner) + } + return m.WorkerMiddlewareDefaults.Work(ctx, job, doInner) +} diff --git a/periodic_job.go b/periodic_job.go index c3db82f2..55d12f4e 100644 --- a/periodic_job.go +++ b/periodic_job.go @@ -4,7 +4,6 @@ import ( "time" "github.com/riverqueue/river/internal/maintenance" - "github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/rivershared/util/sliceutil" "github.com/riverqueue/river/rivertype" ) @@ -180,7 +179,7 @@ func (b *PeriodicJobBundle) toInternal(periodicJob *PeriodicJob) *maintenance.Pe opts = periodicJob.opts } return &maintenance.PeriodicJob{ - ConstructorFunc: func() (*riverdriver.JobInsertFastParams, error) { + ConstructorFunc: func() (*rivertype.JobInsertParams, error) { args, options := periodicJob.constructorFunc() if args == nil { return nil, maintenance.ErrNoJobToInsert diff --git a/producer.go b/producer.go index 0a9d8e53..6d6392ad 100644 --- a/producer.go +++ b/producer.go @@ -63,8 +63,9 @@ type producerConfig struct { // LISTEN/NOTIFY, but this provides a fallback. FetchPollInterval time.Duration - JobTimeout time.Duration - MaxWorkers int + GlobalMiddleware []rivertype.WorkerMiddleware + JobTimeout time.Duration + MaxWorkers int // Notifier is a notifier for subscribing to new job inserts and job // control. If nil, the producer will operate in poll-only mode. @@ -579,6 +580,7 @@ func (p *producer) startNewExecutors(workCtx context.Context, jobs []*rivertype. 
Completer: p.completer, ErrorHandler: p.errorHandler, InformProducerDoneFunc: p.handleWorkerDone, + GlobalMiddleware: p.config.GlobalMiddleware, JobRow: job, SchedulerInterval: p.config.SchedulerInterval, WorkUnit: workUnit, diff --git a/producer_test.go b/producer_test.go index 7c79cb00..ad7f9d8d 100644 --- a/producer_test.go +++ b/producer_test.go @@ -105,7 +105,7 @@ func Test_Producer_CanSafelyCompleteJobsWhileFetchingNewOnes(t *testing.T) { insertParams, err := insertParamsFromConfigArgsAndOptions(archetype, config, WithJobNumArgs{JobNum: i}, nil) require.NoError(err) - params[i] = insertParams + params[i] = (*riverdriver.JobInsertFastParams)(insertParams) } ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) @@ -295,7 +295,7 @@ func testProducer(t *testing.T, makeProducer func(ctx context.Context, t *testin insertParams.ScheduledAt = &bundle.timeBeforeStart } - _, err = bundle.exec.JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{insertParams}) + _, err = bundle.exec.JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{(*riverdriver.JobInsertFastParams)(insertParams)}) require.NoError(t, err) } diff --git a/rivertype/river_type.go b/rivertype/river_type.go index 1998577a..256531ab 100644 --- a/rivertype/river_type.go +++ b/rivertype/river_type.go @@ -4,6 +4,7 @@ package rivertype import ( + "context" "errors" "time" ) @@ -229,6 +230,49 @@ type AttemptError struct { Trace string `json:"trace"` } +type JobInsertParams struct { + Args JobArgs + CreatedAt *time.Time + EncodedArgs []byte + Kind string + MaxAttempts int + Metadata []byte + Priority int + Queue string + ScheduledAt *time.Time + State JobState + Tags []string + UniqueKey []byte + UniqueStates byte +} + +// JobInsertMiddleware provides an interface for middleware that integrations can +// use to encapsulate common logic around job insertion. 
+// +// Implementations should embed river.JobInsertMiddlewareDefaults to inherit default +// implementations for phases where no custom code is needed, and for forward +// compatibility in case new functions are added to this interface. +type JobInsertMiddleware interface { + // InsertMany is invoked around a batch insert operation. Implementations + // must always include a call to doInner to call down the middleware stack + // and perform the batch insertion, and may run custom code before and after. + // + // Returning an error from this function will fail the overarching insert + // operation, even if the inner insertion originally succeeded. + InsertMany(ctx context.Context, manyParams []*JobInsertParams, doInner func(context.Context) ([]*JobInsertResult, error)) ([]*JobInsertResult, error) +} + +type WorkerMiddleware interface { + // Work is invoked after a job's JSON args have been unmarshaled and before the + // job is worked. Implementations must always include a call to doInner to + // call down the middleware stack and perform the job's work, and may run + // custom code before and after. + // + // Returning an error from this function will fail the overarching work + // operation, even if the inner work originally succeeded. + Work(ctx context.Context, job *JobRow, doInner func(context.Context) error) error +} + // PeriodicJobHandle is a reference to a dynamically added periodic job // (returned by the use of `Client.PeriodicJobs().Add()`) which can be used to // subsequently remove the periodic job with `Remove()`. 
diff --git a/work_unit_wrapper.go b/work_unit_wrapper.go index 9e741ee7..9b787ad1 100644 --- a/work_unit_wrapper.go +++ b/work_unit_wrapper.go @@ -29,6 +29,10 @@ func (w *wrapperWorkUnit[T]) NextRetry() time.Time { return w.worker.N func (w *wrapperWorkUnit[T]) Timeout() time.Duration { return w.worker.Timeout(w.job) } func (w *wrapperWorkUnit[T]) Work(ctx context.Context) error { return w.worker.Work(ctx, w.job) } +func (w *wrapperWorkUnit[T]) Middleware() []rivertype.WorkerMiddleware { + return w.worker.Middleware(w.job) +} + func (w *wrapperWorkUnit[T]) UnmarshalJob() error { w.job = &Job[T]{ JobRow: w.jobRow, diff --git a/worker.go b/worker.go index f4af8b37..bf188ab9 100644 --- a/worker.go +++ b/worker.go @@ -6,6 +6,7 @@ import ( "time" "github.com/riverqueue/river/internal/workunit" + "github.com/riverqueue/river/rivertype" ) // Worker is an interface that can perform a job with args of type T. A typical @@ -36,6 +37,9 @@ import ( // In addition to fulfilling the Worker interface, workers must be registered // with the client using the AddWorker function. type Worker[T JobArgs] interface { + // Middleware returns the type-specific middleware for this job. + Middleware(job *Job[T]) []rivertype.WorkerMiddleware + // NextRetry calculates when the next retry for a failed job should take // place given when it was last attempted and its number of attempts, or any // other of the job's properties a user-configured retry policy might want @@ -70,6 +74,8 @@ type Worker[T JobArgs] interface { // struct to make it fulfill the Worker interface with default values. type WorkerDefaults[T JobArgs] struct{} +func (w WorkerDefaults[T]) Middleware(*Job[T]) []rivertype.WorkerMiddleware { return nil } + // NextRetry returns an empty time.Time{} to avoid setting any job or // Worker-specific overrides on the next retry time. This means that the // Client-level retry policy schedule will be used instead.