Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add middleware system for jobs #632

Merged
merged 1 commit into from
Oct 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

⚠️ Version 0.13.0 removes the original advisory lock based unique jobs implementation that was deprecated in v0.12.0. See details in the note below or the v0.12.0 release notes.

### Added

- A middleware system was added for job insertion and execution, providing the ability to extract shared functionality across workers. Both `JobInsertMiddleware` and `WorkerMiddleware` can be configured globally on the `Client`, and `WorkerMiddleware` can also be added on a per-worker basis using the new `Middleware` method on `Worker[T]`. Middleware can be useful for logging, telemetry, or for building higher level abstractions on top of base River functionality.

Despite the interface expansion, users should not encounter any breakage if they're embedding the `WorkerDefaults` type in their workers as recommended. [PR #632](https://github.com/riverqueue/river/pull/632).

### Changed

- **Breaking change:** The advisory lock unique jobs implementation which was deprecated in v0.12.0 has been removed. Users of that feature should first upgrade to v0.12.1 to ensure they don't see any warning logs about using the deprecated advisory lock uniqueness. The new, faster unique implementation will be used automatically as long as the `UniqueOpts.ByState` list hasn't been customized to remove [required states](https://riverqueue.com/docs/unique-jobs#unique-by-state) (`pending`, `scheduled`, `available`, and `running`). As of this release, customizing `ByState` without these required states returns an error. [PR #614](https://github.com/riverqueue/river/pull/614).
Expand Down
68 changes: 49 additions & 19 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ type Config struct {
// deployments.
JobCleanerTimeout time.Duration

// JobInsertMiddleware are optional functions that can be called around job
// insertion.
JobInsertMiddleware []rivertype.JobInsertMiddleware

// JobTimeout is the maximum amount of time a job is allowed to run before its
// context is cancelled. A timeout of zero means JobTimeoutDefault will be
// used, whereas a value of -1 means the job's context will not be cancelled
Expand Down Expand Up @@ -235,6 +239,10 @@ type Config struct {
// (i.e. That it wasn't forgotten by accident.)
Workers *Workers

// WorkerMiddleware are optional functions that can be called around
// all job executions.
WorkerMiddleware []rivertype.WorkerMiddleware

// Scheduler run interval. Shared between the scheduler and producer/job
// executors, but not currently exposed for configuration.
schedulerInterval time.Duration
Expand Down Expand Up @@ -467,6 +475,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
FetchCooldown: valutil.ValOrDefault(config.FetchCooldown, FetchCooldownDefault),
FetchPollInterval: valutil.ValOrDefault(config.FetchPollInterval, FetchPollIntervalDefault),
ID: valutil.ValOrDefaultFunc(config.ID, func() string { return defaultClientID(time.Now().UTC()) }),
JobInsertMiddleware: config.JobInsertMiddleware,
JobTimeout: valutil.ValOrDefault(config.JobTimeout, JobTimeoutDefault),
Logger: logger,
MaxAttempts: valutil.ValOrDefault(config.MaxAttempts, MaxAttemptsDefault),
Expand All @@ -478,6 +487,7 @@ func NewClient[TTx any](driver riverdriver.Driver[TTx], config *Config) (*Client
RetryPolicy: retryPolicy,
TestOnly: config.TestOnly,
Workers: config.Workers,
WorkerMiddleware: config.WorkerMiddleware,
schedulerInterval: valutil.ValOrDefault(config.schedulerInterval, maintenance.JobSchedulerIntervalDefault),
time: config.time,
}
Expand Down Expand Up @@ -1165,7 +1175,7 @@ func (c *Client[TTx]) ID() string {
return c.config.ID
}

func insertParamsFromConfigArgsAndOptions(archetype *baseservice.Archetype, config *Config, args JobArgs, insertOpts *InsertOpts) (*riverdriver.JobInsertFastParams, error) {
func insertParamsFromConfigArgsAndOptions(archetype *baseservice.Archetype, config *Config, args JobArgs, insertOpts *InsertOpts) (*rivertype.JobInsertParams, error) {
encodedArgs, err := json.Marshal(args)
if err != nil {
return nil, fmt.Errorf("error marshaling args to JSON: %w", err)
Expand Down Expand Up @@ -1227,13 +1237,13 @@ func insertParamsFromConfigArgsAndOptions(archetype *baseservice.Archetype, conf
metadata = []byte("{}")
}

insertParams := &riverdriver.JobInsertFastParams{
insertParams := &rivertype.JobInsertParams{
Args: args,
CreatedAt: createdAt,
EncodedArgs: json.RawMessage(encodedArgs),
EncodedArgs: encodedArgs,
Kind: args.Kind(),
MaxAttempts: maxAttempts,
Metadata: json.RawMessage(metadata),
Metadata: metadata,
Priority: priority,
Queue: queue,
State: rivertype.JobStateAvailable,
Expand Down Expand Up @@ -1436,39 +1446,58 @@ func (c *Client[TTx]) insertMany(ctx context.Context, tx riverdriver.ExecutorTx,
func (c *Client[TTx]) insertManyShared(
ctx context.Context,
tx riverdriver.ExecutorTx,
params []InsertManyParams,
rawParams []InsertManyParams,
execute func(context.Context, []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error),
) ([]*rivertype.JobInsertResult, error) {
insertParams, err := c.insertManyParams(params)
insertParams, err := c.insertManyParams(rawParams)
if err != nil {
return nil, err
}

inserted, err := execute(ctx, insertParams)
if err != nil {
return inserted, err
}
doInner := func(ctx context.Context) ([]*rivertype.JobInsertResult, error) {
finalInsertParams := sliceutil.Map(insertParams, func(params *rivertype.JobInsertParams) *riverdriver.JobInsertFastParams {
return (*riverdriver.JobInsertFastParams)(params)
})
results, err := execute(ctx, finalInsertParams)
if err != nil {
return results, err
}

queues := make([]string, 0, 10)
for _, params := range insertParams {
if params.State == rivertype.JobStateAvailable {
queues = append(queues, params.Queue)
queues := make([]string, 0, 10)
for _, params := range insertParams {
if params.State == rivertype.JobStateAvailable {
queues = append(queues, params.Queue)
}
}
if err := c.maybeNotifyInsertForQueues(ctx, tx, queues); err != nil {
return nil, err
}
return results, nil
}
if err := c.maybeNotifyInsertForQueues(ctx, tx, queues); err != nil {
return nil, err

if len(c.config.JobInsertMiddleware) > 0 {
// Wrap middlewares in reverse order so the one defined first is wrapped
// as the outermost function and is first to receive the operation.
for i := len(c.config.JobInsertMiddleware) - 1; i >= 0; i-- {
middlewareItem := c.config.JobInsertMiddleware[i] // capture the current middleware item
previousDoInner := doInner // Capture the current doInner function
doInner = func(ctx context.Context) ([]*rivertype.JobInsertResult, error) {
return middlewareItem.InsertMany(ctx, insertParams, previousDoInner)
}
}
}
return inserted, nil

return doInner(ctx)
}

// Validates input parameters for a batch insert operation and generates a set
// of batch insert parameters.
func (c *Client[TTx]) insertManyParams(params []InsertManyParams) ([]*riverdriver.JobInsertFastParams, error) {
func (c *Client[TTx]) insertManyParams(params []InsertManyParams) ([]*rivertype.JobInsertParams, error) {
if len(params) < 1 {
return nil, errors.New("no jobs to insert")
}

insertParams := make([]*riverdriver.JobInsertFastParams, len(params))
insertParams := make([]*rivertype.JobInsertParams, len(params))
for i, param := range params {
if err := c.validateJobArgs(param.Args); err != nil {
return nil, err
Expand Down Expand Up @@ -1665,6 +1694,7 @@ func (c *Client[TTx]) addProducer(queueName string, queueConfig QueueConfig) *pr
ErrorHandler: c.config.ErrorHandler,
FetchCooldown: c.config.FetchCooldown,
FetchPollInterval: c.config.FetchPollInterval,
GlobalMiddleware: c.config.WorkerMiddleware,
JobTimeout: c.config.JobTimeout,
MaxWorkers: queueConfig.MaxWorkers,
Notifier: c.notifier,
Expand Down
157 changes: 155 additions & 2 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/jackc/pgx/v5/stdlib"
"github.com/robfig/cron/v3"
"github.com/stretchr/testify/require"
"github.com/tidwall/sjson"

"github.com/riverqueue/river/internal/dbunique"
"github.com/riverqueue/river/internal/maintenance"
Expand Down Expand Up @@ -589,6 +590,90 @@ func Test_Client(t *testing.T) {
require.Equal(t, `relation "river_job" does not exist`, pgErr.Message)
})

t.Run("WithWorkerMiddleware", func(t *testing.T) {
t.Parallel()

_, bundle := setup(t)
middlewareCalled := false

type privateKey string

middleware := &overridableJobMiddleware{
workFunc: func(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error {
ctx = context.WithValue(ctx, privateKey("middleware"), "called")
middlewareCalled = true
return doInner(ctx)
},
}
bundle.config.WorkerMiddleware = []rivertype.WorkerMiddleware{middleware}

AddWorker(bundle.config.Workers, WorkFunc(func(ctx context.Context, job *Job[callbackArgs]) error {
require.Equal(t, "called", ctx.Value(privateKey("middleware")))
return nil
}))

driver := riverpgxv5.New(bundle.dbPool)
client, err := NewClient(driver, bundle.config)
require.NoError(t, err)

subscribeChan := subscribe(t, client)
startClient(ctx, t, client)

result, err := client.Insert(ctx, callbackArgs{}, nil)
require.NoError(t, err)

event := riversharedtest.WaitOrTimeout(t, subscribeChan)
require.Equal(t, EventKindJobCompleted, event.Kind)
require.Equal(t, result.Job.ID, event.Job.ID)
require.True(t, middlewareCalled)
})

t.Run("WithWorkerMiddlewareOnWorker", func(t *testing.T) {
t.Parallel()

_, bundle := setup(t)
middlewareCalled := false

type privateKey string

worker := &workerWithMiddleware[callbackArgs]{
workFunc: func(ctx context.Context, job *Job[callbackArgs]) error {
require.Equal(t, "called", ctx.Value(privateKey("middleware")))
return nil
},
middlewareFunc: func(job *Job[callbackArgs]) []rivertype.WorkerMiddleware {
require.Equal(t, "middleware_test", job.Args.Name, "JSON should be decoded before middleware is called")

return []rivertype.WorkerMiddleware{
&overridableJobMiddleware{
workFunc: func(ctx context.Context, job *rivertype.JobRow, doInner func(ctx context.Context) error) error {
ctx = context.WithValue(ctx, privateKey("middleware"), "called")
middlewareCalled = true
return doInner(ctx)
},
},
}
},
}

AddWorker(bundle.config.Workers, worker)

driver := riverpgxv5.New(bundle.dbPool)
client, err := NewClient(driver, bundle.config)
require.NoError(t, err)

subscribeChan := subscribe(t, client)
startClient(ctx, t, client)

result, err := client.Insert(ctx, callbackArgs{Name: "middleware_test"}, nil)
require.NoError(t, err)

event := riversharedtest.WaitOrTimeout(t, subscribeChan)
require.Equal(t, EventKindJobCompleted, event.Kind)
require.Equal(t, result.Job.ID, event.Job.ID)
require.True(t, middlewareCalled)
})

t.Run("PauseAndResumeSingleQueue", func(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -835,6 +920,20 @@ func Test_Client(t *testing.T) {
})
}

type workerWithMiddleware[T JobArgs] struct {
WorkerDefaults[T]
workFunc func(context.Context, *Job[T]) error
middlewareFunc func(*Job[T]) []rivertype.WorkerMiddleware
}

func (w *workerWithMiddleware[T]) Work(ctx context.Context, job *Job[T]) error {
return w.workFunc(ctx, job)
}

func (w *workerWithMiddleware[T]) Middleware(job *Job[T]) []rivertype.WorkerMiddleware {
return w.middlewareFunc(job)
}

func Test_Client_Stop(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -2420,6 +2519,48 @@ func Test_Client_InsertManyTx(t *testing.T) {
require.Len(t, results, 1)
})

t.Run("WithJobInsertMiddleware", func(t *testing.T) {
t.Parallel()

_, bundle := setup(t)
config := newTestConfig(t, nil)
config.Queues = nil

insertCalled := false
var innerResults []*rivertype.JobInsertResult

middleware := &overridableJobMiddleware{
insertManyFunc: func(ctx context.Context, manyParams []*rivertype.JobInsertParams, doInner func(ctx context.Context) ([]*rivertype.JobInsertResult, error)) ([]*rivertype.JobInsertResult, error) {
insertCalled = true
var err error
for _, params := range manyParams {
params.Metadata, err = sjson.SetBytes(params.Metadata, "middleware", "called")
require.NoError(t, err)
}

results, err := doInner(ctx)
require.NoError(t, err)
innerResults = results
return results, nil
},
}

config.JobInsertMiddleware = []rivertype.JobInsertMiddleware{middleware}
driver := riverpgxv5.New(nil)
client, err := NewClient(driver, config)
require.NoError(t, err)

results, err := client.InsertManyTx(ctx, bundle.tx, []InsertManyParams{{Args: noOpArgs{}}})
require.NoError(t, err)
require.Len(t, results, 1)

require.True(t, insertCalled)
require.Len(t, innerResults, 1)
require.Len(t, results, 1)
require.Equal(t, innerResults[0].Job.ID, results[0].Job.ID)
require.JSONEq(t, `{"middleware": "called"}`, string(results[0].Job.Metadata))
})

t.Run("WithUniqueOpts", func(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -2998,7 +3139,7 @@ func Test_Client_ErrorHandler(t *testing.T) {
// unknown job.
insertParams, err := insertParamsFromConfigArgsAndOptions(&client.baseService.Archetype, config, unregisteredJobArgs{}, nil)
require.NoError(t, err)
_, err = client.driver.GetExecutor().JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{insertParams})
_, err = client.driver.GetExecutor().JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{(*riverdriver.JobInsertFastParams)(insertParams)})
require.NoError(t, err)

riversharedtest.WaitOrTimeout(t, bundle.SubscribeChan)
Expand Down Expand Up @@ -4600,7 +4741,7 @@ func Test_Client_UnknownJobKindErrorsTheJob(t *testing.T) {

insertParams, err := insertParamsFromConfigArgsAndOptions(&client.baseService.Archetype, config, unregisteredJobArgs{}, nil)
require.NoError(err)
insertedResults, err := client.driver.GetExecutor().JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{insertParams})
insertedResults, err := client.driver.GetExecutor().JobInsertFastMany(ctx, []*riverdriver.JobInsertFastParams{(*riverdriver.JobInsertFastParams)(insertParams)})
require.NoError(err)

insertedResult := insertedResults[0]
Expand Down Expand Up @@ -4770,6 +4911,14 @@ func Test_NewClient_Overrides(t *testing.T) {

retryPolicy := &DefaultClientRetryPolicy{}

type noOpInsertMiddleware struct {
JobInsertMiddlewareDefaults
}

type noOpWorkerMiddleware struct {
WorkerMiddlewareDefaults
}

client, err := NewClient(riverpgxv5.New(dbPool), &Config{
AdvisoryLockPrefix: 123_456,
CancelledJobRetentionPeriod: 1 * time.Hour,
Expand All @@ -4778,13 +4927,15 @@ func Test_NewClient_Overrides(t *testing.T) {
ErrorHandler: errorHandler,
FetchCooldown: 123 * time.Millisecond,
FetchPollInterval: 124 * time.Millisecond,
JobInsertMiddleware: []rivertype.JobInsertMiddleware{&noOpInsertMiddleware{}},
JobTimeout: 125 * time.Millisecond,
Logger: logger,
MaxAttempts: 5,
Queues: map[string]QueueConfig{QueueDefault: {MaxWorkers: 1}},
RetryPolicy: retryPolicy,
TestOnly: true, // disables staggered start in maintenance services
Workers: workers,
WorkerMiddleware: []rivertype.WorkerMiddleware{&noOpWorkerMiddleware{}},
})
require.NoError(t, err)

Expand All @@ -4803,10 +4954,12 @@ func Test_NewClient_Overrides(t *testing.T) {
require.Equal(t, errorHandler, client.config.ErrorHandler)
require.Equal(t, 123*time.Millisecond, client.config.FetchCooldown)
require.Equal(t, 124*time.Millisecond, client.config.FetchPollInterval)
require.Len(t, client.config.JobInsertMiddleware, 1)
require.Equal(t, 125*time.Millisecond, client.config.JobTimeout)
require.Equal(t, logger, client.baseService.Logger)
require.Equal(t, 5, client.config.MaxAttempts)
require.Equal(t, retryPolicy, client.config.RetryPolicy)
require.Len(t, client.config.WorkerMiddleware, 1)
}

func Test_NewClient_MissingParameters(t *testing.T) {
Expand Down
Loading
Loading