Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: impl PoolWaitTimeoutError #7

Merged
merged 1 commit into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ const db = new RDSClient({
// connectionStorage: new AsyncLocalStorage(),
// If create multiple RDSClient instances with the same connectionStorage, use this key to distinguish between the instances
// connectionStorageKey: 'datasource',

// The timeout for connecting to the MySQL server. (Default: 500 milliseconds)
// connectTimeout: 500,

// The timeout for waiting for a connection from the connection pool. (Default: 500 milliseconds)
// So max timeout for get a connection is (connectTimeout + poolWaitTimeout)
// poolWaitTimeout: 500,
});
```

Expand Down
56 changes: 50 additions & 6 deletions src/client.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { AsyncLocalStorage } from 'node:async_hooks';
import { promisify } from 'node:util';
import { setTimeout } from 'node:timers/promises';
import mysql, { Pool } from 'mysql2';
import type { PoolOptions } from 'mysql2';
import type { PoolConnectionPromisify, RDSClientOptions, TransactionContext, TransactionScope } from './types';
Expand All @@ -10,12 +11,17 @@ import literals from './literals';
import channels from './channels';
import type { ConnectionMessage, ConnectionEnqueueMessage } from './channels';
import { RDSPoolConfig } from './PoolConfig';
import { PoolWaitTimeoutError } from './util/PoolWaitTimeout';

export * from './types';

interface PoolPromisify extends Omit<Pool, 'query'> {
query(sql: string): Promise<any>;

getConnection(): Promise<PoolConnectionPromisify>;

end(): Promise<void>;

_acquiringConnections: any[];
_allConnections: any[];
_freeConnections: any[];
Expand All @@ -30,21 +36,37 @@ export interface QueryOptions {
}

export class RDSClient extends Operator {
static get literals() { return literals; }
static get escape() { return mysql.escape; }
static get escapeId() { return mysql.escapeId; }
static get format() { return mysql.format; }
static get raw() { return mysql.raw; }
static get literals() {
return literals;
}

static get escape() {
return mysql.escape;
}

static get escapeId() {
return mysql.escapeId;
}

static get format() {
return mysql.format;
}

static get raw() {
return mysql.raw;
}

static #DEFAULT_STORAGE_KEY = Symbol('RDSClient#storage#default');
static #TRANSACTION_NEST_COUNT = Symbol('RDSClient#transaction#nestCount');

#pool: PoolPromisify;
#connectionStorage: AsyncLocalStorage<TransactionContext>;
#connectionStorageKey: string | symbol;
#poolWaitTimeout: number;

constructor(options: RDSClientOptions) {
super();
options.connectTimeout = options.connectTimeout ?? 500;
const { connectionStorage, connectionStorageKey, ...mysqlOptions } = options;
// get connection options from getConnectionConfig method every time
if (mysqlOptions.getConnectionConfig) {
Expand All @@ -61,6 +83,7 @@ export class RDSClient extends Operator {
});
this.#connectionStorage = connectionStorage || new AsyncLocalStorage();
this.#connectionStorageKey = connectionStorageKey || RDSClient.#DEFAULT_STORAGE_KEY;
this.#poolWaitTimeout = options.poolWaitTimeout ?? 500;
// https://github.com/mysqljs/mysql#pool-events
this.#pool.on('connection', (connection: PoolConnectionPromisify) => {
channels.connectionNew.publish({
Expand Down Expand Up @@ -129,9 +152,30 @@ export class RDSClient extends Operator {
};
}

async waitPoolConnection(abortSignal: AbortSignal) {
const now = performance.now();
await setTimeout(this.#poolWaitTimeout, undefined, { signal: abortSignal });
return performance.now() - now;
}

async getConnectionWithTimeout() {
const connPromise = this.#pool.getConnection();
const timeoutAbortController = new AbortController();
const timeoutPromise = this.waitPoolConnection(timeoutAbortController.signal);
const connOrTimeout = await Promise.race([ connPromise, timeoutPromise ]);
if (typeof connOrTimeout === 'number') {
connPromise.then(conn => {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

后续得去给 getConnection 加一个 abortSignal 支持,就不需要写下面这种 hack 逻辑。
connPromise 这段逻辑代码其实也奇怪,一拿到 conn 就马上释放。

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mysql 还停留在 callback 的写法,这个改动比较难了。

conn.release();
});
throw new PoolWaitTimeoutError(`get connection timeout after ${connOrTimeout}ms`);
}
timeoutAbortController.abort();
return connPromise;
}

async getConnection() {
try {
const _conn = await this.#pool.getConnection();
const _conn = await this.getConnectionWithTimeout();
const conn = new RDSConnection(_conn);
if (this.beforeQueryHandlers.length > 0) {
for (const handler of this.beforeQueryHandlers) {
Expand Down
6 changes: 6 additions & 0 deletions src/connection.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import assert from 'node:assert';
import { promisify } from 'node:util';
import { Operator } from './operator';
import type { PoolConnectionPromisify } from './types';
Expand All @@ -6,8 +7,11 @@ const kWrapToRDS = Symbol('kWrapToRDS');

export class RDSConnection extends Operator {
conn: PoolConnectionPromisify;
#released: boolean;

constructor(conn: PoolConnectionPromisify) {
super(conn);
this.#released = false;
this.conn = conn;
if (!this.conn[kWrapToRDS]) {
[
Expand All @@ -23,6 +27,8 @@ export class RDSConnection extends Operator {
}

release() {
assert(!this.#released, 'connection was released');
this.#released = true;
return this.conn.release();
}

Expand Down
4 changes: 4 additions & 0 deletions src/transaction.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import type { RDSConnection } from './connection';
import { Operator } from './operator';

let id = 0;
export class RDSTransaction extends Operator {
isCommit = false;
isRollback = false;
conn: RDSConnection | null;
id: number;

constructor(conn: RDSConnection) {
super(conn.conn);
this.id = id++;
this.conn = conn;
}

Expand Down
1 change: 1 addition & 0 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export interface RDSClientOptions extends PoolOptions {
connectionStorageKey?: string;
connectionStorage?: AsyncLocalStorage<Record<PropertyKey, RDSTransaction>>;
getConnectionConfig?: GetConnectionConfig;
poolWaitTimeout?: number;
}

export interface PoolConnectionPromisify extends Omit<PoolConnection, 'query'> {
Expand Down
6 changes: 6 additions & 0 deletions src/util/PoolWaitTimeout.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
export class PoolWaitTimeoutError extends Error {
constructor(...args) {
super(...args);
this.name = 'PoolWaitTimeoutError';
}
}
106 changes: 103 additions & 3 deletions test/client.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { AsyncLocalStorage } from 'node:async_hooks';
import { strict as assert } from 'node:assert';
import fs from 'node:fs/promises';
import { setTimeout } from 'node:timers/promises';
import path from 'node:path';
import mm from 'mm';
import { RDSTransaction } from '../src/transaction';
Expand Down Expand Up @@ -298,8 +299,9 @@ describe('test/client.test.ts', () => {
// recovered after unlock.
await conn.query('select * from `myrds-test-user` limit 1;');
} catch (err) {
conn.release();
throw err;
} finally {
conn.release();
}
});

Expand Down Expand Up @@ -353,7 +355,8 @@ describe('test/client.test.ts', () => {
});

it('should throw rollback error with cause error when rollback failed', async () => {
mm(RDSTransaction.prototype, 'rollback', async () => {
mm(RDSTransaction.prototype, 'rollback', async function(this: RDSTransaction) {
this.conn!.release();
throw new Error('fake rollback error');
});
await assert.rejects(
Expand Down Expand Up @@ -501,7 +504,9 @@ describe('test/client.test.ts', () => {
});
};

const [ p1Res, p2Res ] = await Promise.all([ p1(), p2().catch(err => err) ]);
const [ p1Res, p2Res ] = await Promise.all([ p1(), p2().catch(err => {
return err;
}) ]);
assert.strictEqual(p1Res, true);
assert.strictEqual(p2Res.code, 'ER_PARSE_ERROR');
const rows = await db.query('select * from ?? where email=? order by id',
Expand Down Expand Up @@ -680,6 +685,7 @@ describe('test/client.test.ts', () => {
});
return db;
});
conn.release();
assert(connQuerySql);
assert(!transactionQuerySql);
});
Expand Down Expand Up @@ -1493,4 +1499,98 @@ describe('test/client.test.ts', () => {
assert.equal(counter2After, 4);
});
});

describe('PoolWaitTimeout', () => {
async function longQuery(timeout?: number) {
await db.beginTransactionScope(async conn => {
await setTimeout(timeout ?? 1000);
await conn.query('SELECT 1+1');
});
}

it('should throw error if pool wait timeout', async () => {
const tasks: Array<Promise<void>> = [];
for (let i = 0; i < 10; i++) {
tasks.push(longQuery());
}
const tasksPromise = Promise.all(tasks);
await assert.rejects(async () => {
await longQuery();
}, /get connection timeout after/);
await tasksPromise;
});

it('should release conn to pool', async () => {
const tasks: Array<Promise<void>> = [];
const timeoutTasks: Array<Promise<void>> = [];
// 1. fill the pool
for (let i = 0; i < 10; i++) {
tasks.push(longQuery());
}
// 2. add more conn and wait for timeout
for (let i = 0; i < 10; i++) {
timeoutTasks.push(longQuery());
}
const [ succeedTasks, failedTasks ] = await Promise.all([
Promise.allSettled(tasks),
Promise.allSettled(timeoutTasks),
]);
const succeedCount = succeedTasks.filter(t => t.status === 'fulfilled').length;
assert.equal(succeedCount, 10);

const failedCount = failedTasks.filter(t => t.status === 'rejected').length;
assert.equal(failedCount, 10);

// 3. after pool empty, create new tasks
const retryTasks: Array<Promise<void>> = [];
for (let i = 0; i < 10; i++) {
retryTasks.push(longQuery());
}
await Promise.all(retryTasks);
});

it('should not wait too long', async () => {
const tasks: Array<Promise<void>> = [];
const timeoutTasks: Array<Promise<void>> = [];
const fastTasks: Array<Promise<void>> = [];
const start = performance.now();
// 1. fill the pool
for (let i = 0; i < 10; i++) {
tasks.push(longQuery());
}
const tasksPromise = Promise.allSettled(tasks);
// 2. add more conn and wait for timeout
for (let i = 0; i < 10; i++) {
timeoutTasks.push(longQuery());
}
const timeoutTasksPromise = Promise.allSettled(timeoutTasks);
await setTimeout(600);
// 3. add fast query
for (let i = 0; i < 10; i++) {
fastTasks.push(longQuery(1));
}
const fastTasksPromise = Promise.allSettled(fastTasks);
const [ succeedTasks, failedTasks, fastTaskResults ] = await Promise.all([
tasksPromise,
timeoutTasksPromise,
fastTasksPromise,
]);
const duration = performance.now() - start;
const succeedCount = succeedTasks.filter(t => t.status === 'fulfilled').length;
assert.equal(succeedCount, 10);

const failedCount = failedTasks.filter(t => t.status === 'rejected').length;
assert.equal(failedCount, 10);

const faskTaskSucceedCount = fastTaskResults.filter(t => t.status === 'fulfilled').length;
assert.equal(faskTaskSucceedCount, 10);

// - 10 long queries cost 1000ms
// - 10 timeout queries should be timeout in long query execution so not cost time
// - 10 fast queries wait long query to finish, cost 1ms
// 1000ms + 0ms + 1ms < 1100ms
assert(duration < 1100);
});

});
});
Loading