Skip to content

Commit

Permalink
unify all bulk inserts to a single code path (#610)
Browse files Browse the repository at this point in the history
  • Loading branch information
bgentry authored Sep 23, 2024
1 parent 185dbb8 commit 9be1d11
Showing 1 changed file with 40 additions and 54 deletions.
94 changes: 40 additions & 54 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1441,14 +1441,41 @@ func (c *Client[TTx]) InsertManyTx(ctx context.Context, tx TTx, params []InsertM
}

func (c *Client[TTx]) insertMany(ctx context.Context, tx riverdriver.ExecutorTx, params []InsertManyParams) ([]*rivertype.JobInsertResult, error) {
return c.insertManyShared(ctx, tx, params, func(ctx context.Context, insertParams []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error) {
results, err := tx.JobInsertFastMany(ctx, insertParams)
if err != nil {
return nil, err
}

return sliceutil.Map(results,
func(result *riverdriver.JobInsertFastResult) *rivertype.JobInsertResult {
return (*rivertype.JobInsertResult)(result)
},
), nil
})
}

// The shared code path for all InsertMany methods. It takes a function that
// executes the actual insert operation and allows for different implementations
// of the insert query to be passed in, each mapping their results back to a
// common result type.
//
// TODO(bgentry): this isn't yet used for the single insert path. The only thing
// blocking that is the removal of advisory lock unique inserts.
func (c *Client[TTx]) insertManyShared(
ctx context.Context,
tx riverdriver.ExecutorTx,
params []InsertManyParams,
execute func(context.Context, []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error),
) ([]*rivertype.JobInsertResult, error) {
insertParams, err := c.insertManyParams(params)
if err != nil {
return nil, err
}

jobRows, err := tx.JobInsertFastMany(ctx, insertParams)
inserted, err := execute(ctx, insertParams)
if err != nil {
return nil, err
return inserted, err
}

queues := make([]string, 0, 10)
Expand All @@ -1460,12 +1487,7 @@ func (c *Client[TTx]) insertMany(ctx context.Context, tx riverdriver.ExecutorTx,
if err := c.maybeNotifyInsertForQueues(ctx, tx, queues); err != nil {
return nil, err
}

return sliceutil.Map(jobRows,
func(result *riverdriver.JobInsertFastResult) *rivertype.JobInsertResult {
return (*rivertype.JobInsertResult)(result)
},
), nil
return inserted, nil
}

// Validates input parameters for a batch insert operation and generates a set
Expand Down Expand Up @@ -1516,19 +1538,14 @@ func (c *Client[TTx]) InsertManyFast(ctx context.Context, params []InsertManyPar
return 0, errNoDriverDBPool
}

insertParams, err := c.insertManyFastParams(params)
if err != nil {
return 0, err
}

// Wrap in a transaction in case we need to notify about inserts.
tx, err := c.driver.GetExecutor().Begin(ctx)
if err != nil {
return 0, err
}
defer tx.Rollback(ctx)

inserted, err := c.insertManyFast(ctx, tx, insertParams)
inserted, err := c.insertManyFast(ctx, tx, params)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -1562,54 +1579,23 @@ func (c *Client[TTx]) InsertManyFast(ctx context.Context, params []InsertManyPar
// unique conflicts cannot be handled gracefully. If a unique constraint is
// violated, the operation will fail and no jobs will be inserted.
func (c *Client[TTx]) InsertManyFastTx(ctx context.Context, tx TTx, params []InsertManyParams) (int, error) {
insertParams, err := c.insertManyFastParams(params)
if err != nil {
return 0, err
}

exec := c.driver.UnwrapExecutor(tx)
return c.insertManyFast(ctx, exec, insertParams)
}

func (c *Client[TTx]) insertManyFast(ctx context.Context, tx riverdriver.ExecutorTx, insertParams []*riverdriver.JobInsertFastParams) (int, error) {
inserted, err := tx.JobInsertFastManyNoReturning(ctx, insertParams)
if err != nil {
return inserted, err
}

queues := make([]string, 0, 10)
for _, params := range insertParams {
if params.State == rivertype.JobStateAvailable {
queues = append(queues, params.Queue)
}
}
if err := c.maybeNotifyInsertForQueues(ctx, tx, queues); err != nil {
return 0, err
}
return inserted, nil
return c.insertManyFast(ctx, exec, params)
}

// Validates input parameters for an a batch insert operation and generates a
// set of batch insert parameters.
func (c *Client[TTx]) insertManyFastParams(params []InsertManyParams) ([]*riverdriver.JobInsertFastParams, error) {
if len(params) < 1 {
return nil, errors.New("no jobs to insert")
}

insertParams := make([]*riverdriver.JobInsertFastParams, len(params))
for i, param := range params {
if err := c.validateJobArgs(param.Args); err != nil {
return nil, err
}

insertParamsItem, _, err := insertParamsFromConfigArgsAndOptions(&c.baseService.Archetype, c.config, param.Args, param.InsertOpts, true)
func (c *Client[TTx]) insertManyFast(ctx context.Context, tx riverdriver.ExecutorTx, params []InsertManyParams) (int, error) {
results, err := c.insertManyShared(ctx, tx, params, func(ctx context.Context, insertParams []*riverdriver.JobInsertFastParams) ([]*rivertype.JobInsertResult, error) {
count, err := tx.JobInsertFastManyNoReturning(ctx, insertParams)
if err != nil {
return nil, err
}
insertParams[i] = insertParamsItem
return make([]*rivertype.JobInsertResult, count), nil
})
if err != nil {
return 0, err
}

return insertParams, nil
return len(results), nil
}

func (c *Client[TTx]) maybeNotifyInsert(ctx context.Context, tx riverdriver.ExecutorTx, state rivertype.JobState, queue string) error {
Expand Down

0 comments on commit 9be1d11

Please sign in to comment.