diff --git a/README.md b/README.md index 5f73e50..c52989f 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,13 @@ const db = new RDSClient({ // connectionStorage: new AsyncLocalStorage(), // If create multiple RDSClient instances with the same connectionStorage, use this key to distinguish between the instances // connectionStorageKey: 'datasource', + + // The timeout for connecting to the MySQL server. (Default: 500 milliseconds) + // connectTimeout: 500, + + // The timeout for waiting for a connection from the connection pool. (Default: 500 milliseconds) + // So max timeout for get a connection is (connectTimeout + poolWaitTimeout) + // poolWaitTimeout: 500, }); ``` diff --git a/src/client.ts b/src/client.ts index a57c1c7..da3d1fc 100644 --- a/src/client.ts +++ b/src/client.ts @@ -1,5 +1,6 @@ import { AsyncLocalStorage } from 'node:async_hooks'; import { promisify } from 'node:util'; +import { setTimeout } from 'node:timers/promises'; import mysql, { Pool } from 'mysql2'; import type { PoolOptions } from 'mysql2'; import type { PoolConnectionPromisify, RDSClientOptions, TransactionContext, TransactionScope } from './types'; @@ -10,12 +11,17 @@ import literals from './literals'; import channels from './channels'; import type { ConnectionMessage, ConnectionEnqueueMessage } from './channels'; import { RDSPoolConfig } from './PoolConfig'; +import { PoolWaitTimeoutError } from './util/PoolWaitTimeout'; + export * from './types'; interface PoolPromisify extends Omit { query(sql: string): Promise; + getConnection(): Promise; + end(): Promise; + _acquiringConnections: any[]; _allConnections: any[]; _freeConnections: any[]; @@ -30,11 +36,25 @@ export interface QueryOptions { } export class RDSClient extends Operator { - static get literals() { return literals; } - static get escape() { return mysql.escape; } - static get escapeId() { return mysql.escapeId; } - static get format() { return mysql.format; } - static get raw() { return mysql.raw; } + static get literals() { + return literals; + } + + static get escape() { + return mysql.escape; + } + + static get escapeId() { + return mysql.escapeId; + } + + static get format() { + return mysql.format; + } + + static get raw() { + return mysql.raw; + } static #DEFAULT_STORAGE_KEY = Symbol('RDSClient#storage#default'); static #TRANSACTION_NEST_COUNT = Symbol('RDSClient#transaction#nestCount'); @@ -42,9 +62,11 @@ export class RDSClient extends Operator { #pool: PoolPromisify; #connectionStorage: AsyncLocalStorage; #connectionStorageKey: string | symbol; + #poolWaitTimeout: number; constructor(options: RDSClientOptions) { super(); + options.connectTimeout = options.connectTimeout ?? 500; const { connectionStorage, connectionStorageKey, ...mysqlOptions } = options; // get connection options from getConnectionConfig method every time if (mysqlOptions.getConnectionConfig) { @@ -61,6 +83,7 @@ export class RDSClient extends Operator { }); this.#connectionStorage = connectionStorage || new AsyncLocalStorage(); this.#connectionStorageKey = connectionStorageKey || RDSClient.#DEFAULT_STORAGE_KEY; + this.#poolWaitTimeout = options.poolWaitTimeout ?? 500; // https://github.com/mysqljs/mysql#pool-events this.#pool.on('connection', (connection: PoolConnectionPromisify) => { channels.connectionNew.publish({ @@ -129,9 +152,30 @@ export class RDSClient extends Operator { }; } + async waitPoolConnection(abortSignal: AbortSignal) { + const now = performance.now(); + await setTimeout(this.#poolWaitTimeout, undefined, { signal: abortSignal }); + return performance.now() - now; + } + + async getConnectionWithTimeout() { + const connPromise = this.#pool.getConnection(); + const timeoutAbortController = new AbortController(); + const timeoutPromise = this.waitPoolConnection(timeoutAbortController.signal); + const connOrTimeout = await Promise.race([ connPromise, timeoutPromise ]); + if (typeof connOrTimeout === 'number') { + connPromise.then(conn => { + conn.release(); + }); + throw new PoolWaitTimeoutError(`get connection timeout after ${connOrTimeout}ms`); + } + timeoutAbortController.abort(); + return connPromise; + } + async getConnection() { try { - const _conn = await this.#pool.getConnection(); + const _conn = await this.getConnectionWithTimeout(); const conn = new RDSConnection(_conn); if (this.beforeQueryHandlers.length > 0) { for (const handler of this.beforeQueryHandlers) { diff --git a/src/connection.ts b/src/connection.ts index 3d1f091..2daaa39 100644 --- a/src/connection.ts +++ b/src/connection.ts @@ -1,3 +1,4 @@ +import assert from 'node:assert'; import { promisify } from 'node:util'; import { Operator } from './operator'; import type { PoolConnectionPromisify } from './types'; @@ -6,8 +7,11 @@ const kWrapToRDS = Symbol('kWrapToRDS'); export class RDSConnection extends Operator { conn: PoolConnectionPromisify; + #released: boolean; + constructor(conn: PoolConnectionPromisify) { super(conn); + this.#released = false; this.conn = conn; if (!this.conn[kWrapToRDS]) { [ @@ -23,6 +27,8 @@ export class RDSConnection extends Operator { } release() { + assert(!this.#released, 'connection was released'); + this.#released = true; return this.conn.release(); } diff --git a/src/transaction.ts b/src/transaction.ts index 7174946..c655b40 100644 --- a/src/transaction.ts +++ b/src/transaction.ts @@ -1,12 +1,16 @@ import type { RDSConnection } from './connection'; import { Operator } from './operator'; +let id = 0; export class RDSTransaction extends Operator { isCommit = false; isRollback = false; conn: RDSConnection | null; + id: number; + constructor(conn: RDSConnection) { super(conn.conn); + this.id = id++; this.conn = conn; } diff --git a/src/types.ts b/src/types.ts index ab4bfb0..f273a35 100644 --- a/src/types.ts +++ b/src/types.ts @@ -8,6 +8,7 @@ export interface RDSClientOptions extends PoolOptions { connectionStorageKey?: string; connectionStorage?: AsyncLocalStorage>; getConnectionConfig?: GetConnectionConfig; + poolWaitTimeout?: number; } export interface PoolConnectionPromisify extends Omit { diff --git a/src/util/PoolWaitTimeout.ts b/src/util/PoolWaitTimeout.ts new file mode 100644 index 0000000..bf4d020 --- /dev/null +++ b/src/util/PoolWaitTimeout.ts @@ -0,0 +1,6 @@ +export class PoolWaitTimeoutError extends Error { + constructor(...args) { + super(...args); + this.name = 'PoolWaitTimeoutError'; + } +} diff --git a/test/client.test.ts b/test/client.test.ts index bf9a660..c614e6c 100644 --- a/test/client.test.ts +++ b/test/client.test.ts @@ -1,6 +1,7 @@ import { AsyncLocalStorage } from 'node:async_hooks'; import { strict as assert } from 'node:assert'; import fs from 'node:fs/promises'; +import { setTimeout } from 'node:timers/promises'; import path from 'node:path'; import mm from 'mm'; import { RDSTransaction } from '../src/transaction'; @@ -298,8 +299,9 @@ describe('test/client.test.ts', () => { // recovered after unlock. await conn.query('select * from `myrds-test-user` limit 1;'); } catch (err) { - conn.release(); throw err; + } finally { + conn.release(); } }); @@ -353,7 +355,8 @@ describe('test/client.test.ts', () => { }); it('should throw rollback error with cause error when rollback failed', async () => { - mm(RDSTransaction.prototype, 'rollback', async () => { + mm(RDSTransaction.prototype, 'rollback', async function(this: RDSTransaction) { + this.conn!.release(); throw new Error('fake rollback error'); }); await assert.rejects( @@ -501,7 +504,9 @@ describe('test/client.test.ts', () => { }); }; - const [ p1Res, p2Res ] = await Promise.all([ p1(), p2().catch(err => err) ]); + const [ p1Res, p2Res ] = await Promise.all([ p1(), p2().catch(err => { + return err; + }) ]); assert.strictEqual(p1Res, true); assert.strictEqual(p2Res.code, 'ER_PARSE_ERROR'); const rows = await db.query('select * from ?? where email=? order by id', @@ -680,6 +685,7 @@ describe('test/client.test.ts', () => { }); return db; }); + conn.release(); assert(connQuerySql); assert(!transactionQuerySql); }); @@ -1493,4 +1499,98 @@ describe('test/client.test.ts', () => { assert.equal(counter2After, 4); }); }); + + describe('PoolWaitTimeout', () => { + async function longQuery(timeout?: number) { + await db.beginTransactionScope(async conn => { + await setTimeout(timeout ?? 1000); + await conn.query('SELECT 1+1'); + }); + } + + it('should throw error if pool wait timeout', async () => { + const tasks: Array> = []; + for (let i = 0; i < 10; i++) { + tasks.push(longQuery()); + } + const tasksPromise = Promise.all(tasks); + await assert.rejects(async () => { + await longQuery(); + }, /get connection timeout after/); + await tasksPromise; + }); + + it('should release conn to pool', async () => { + const tasks: Array> = []; + const timeoutTasks: Array> = []; + // 1. fill the pool + for (let i = 0; i < 10; i++) { + tasks.push(longQuery()); + } + // 2. add more conn and wait for timeout + for (let i = 0; i < 10; i++) { + timeoutTasks.push(longQuery()); + } + const [ succeedTasks, failedTasks ] = await Promise.all([ + Promise.allSettled(tasks), + Promise.allSettled(timeoutTasks), + ]); + const succeedCount = succeedTasks.filter(t => t.status === 'fulfilled').length; + assert.equal(succeedCount, 10); + + const failedCount = failedTasks.filter(t => t.status === 'rejected').length; + assert.equal(failedCount, 10); + + // 3. after pool empty, create new tasks + const retryTasks: Array> = []; + for (let i = 0; i < 10; i++) { + retryTasks.push(longQuery()); + } + await Promise.all(retryTasks); + }); + + it('should not wait too long', async () => { + const tasks: Array> = []; + const timeoutTasks: Array> = []; + const fastTasks: Array> = []; + const start = performance.now(); + // 1. fill the pool + for (let i = 0; i < 10; i++) { + tasks.push(longQuery()); + } + const tasksPromise = Promise.allSettled(tasks); + // 2. add more conn and wait for timeout + for (let i = 0; i < 10; i++) { + timeoutTasks.push(longQuery()); + } + const timeoutTasksPromise = Promise.allSettled(timeoutTasks); + await setTimeout(600); + // 3. add fast query + for (let i = 0; i < 10; i++) { + fastTasks.push(longQuery(1)); + } + const fastTasksPromise = Promise.allSettled(fastTasks); + const [ succeedTasks, failedTasks, fastTaskResults ] = await Promise.all([ + tasksPromise, + timeoutTasksPromise, + fastTasksPromise, + ]); + const duration = performance.now() - start; + const succeedCount = succeedTasks.filter(t => t.status === 'fulfilled').length; + assert.equal(succeedCount, 10); + + const failedCount = failedTasks.filter(t => t.status === 'rejected').length; + assert.equal(failedCount, 10); + + const faskTaskSucceedCount = fastTaskResults.filter(t => t.status === 'fulfilled').length; + assert.equal(faskTaskSucceedCount, 10); + + // - 10 long queries cost 1000ms + // - 10 timeout queries should be timeout in long query execution so not cost time + // - 10 fast queries wait long query to finish, cost 1ms + // 1000ms + 0ms + 1ms < 1100ms + assert(duration < 1100); + }); + + }); });