Skip to content

Commit

Permalink
chore: Refactor Async(Postgresql/MySQL)Connection
Browse files Browse the repository at this point in the history
  • Loading branch information
akerlay committed Apr 3, 2024
1 parent 9e3dda8 commit 1c55d89
Showing 1 changed file with 42 additions and 63 deletions.
105 changes: 42 additions & 63 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Copyright (c) 2014, Alexey Kinëv <rudy@05bit.com>
"""
import abc
import asyncio
import contextlib
import functools
Expand Down Expand Up @@ -851,19 +852,14 @@ async def aio_execute(self, query):
return (await coroutine(query))


##############
# PostgreSQL #
##############


class AsyncPostgresqlConnection:
class AsyncConnectionPool(metaclass=abc.ABCMeta):
"""Asynchronous database connection pool.
"""
def __init__(self, *, database=None, loop=None, timeout=None, **kwargs):
self.pool = None
self.loop = loop
self.database = database
self.timeout = timeout or aiopg.DEFAULT_TIMEOUT
self.timeout = timeout
self.connect_params = kwargs

async def acquire(self):
Expand All @@ -876,14 +872,11 @@ def release(self, conn):
"""
self.pool.release(conn)

@abc.abstractmethod
async def connect(self):
"""Create connection pool asynchronously.
"""
self.pool = await aiopg.create_pool(
loop=self.loop,
timeout=self.timeout,
database=self.database,
**self.connect_params)
raise NotImplementedError

async def close(self):
"""Terminate all pool connections.
Expand All @@ -892,8 +885,7 @@ async def close(self):
await self.pool.wait_closed()

async def cursor(self, conn=None, *args, **kwargs):
"""Get a cursor for the specified transaction connection
or acquire from the pool.
"""Get cursor for connection from pool.
"""
in_transaction = conn is not None
if not conn:
Expand All @@ -909,6 +901,41 @@ async def cursor(self, conn=None, *args, **kwargs):
in_transaction=in_transaction)
return cursor

async def release_cursor(self, cursor, in_transaction=False):
"""Release cursor coroutine. Unless in transaction,
the connection is also released back to the pool.
"""
conn = cursor.connection
await cursor.close()
if not in_transaction:
self.release(conn)


##############
# PostgreSQL #
##############


class AsyncPostgresqlConnection(AsyncConnectionPool):
"""Asynchronous database connection pool.
"""
def __init__(self, *, database=None, loop=None, timeout=None, **kwargs):
super().__init__(
database=database,
loop=loop,
timeout=timeout or aiopg.DEFAULT_TIMEOUT,
**kwargs,
)

async def connect(self):
"""Create connection pool asynchronously.
"""
self.pool = await aiopg.create_pool(
loop=self.loop,
timeout=self.timeout,
database=self.database,
**self.connect_params)

async def release_cursor(self, cursor, in_transaction=False):
"""Release cursor coroutine. Unless in transaction,
the connection is also released back to the pool.
Expand Down Expand Up @@ -1027,25 +1054,9 @@ def use_speedups(self, value):
#########


class AsyncMySQLConnection:
class AsyncMySQLConnection(AsyncConnectionPool):
"""Asynchronous database connection pool.
"""
def __init__(self, *, database=None, loop=None, timeout=None, **kwargs):
self.pool = None
self.loop = loop
self.database = database
self.timeout = timeout
self.connect_params = kwargs

async def acquire(self):
"""Acquire connection from pool.
"""
return (await self.pool.acquire())

def release(self, conn):
"""Release connection to pool.
"""
self.pool.release(conn)

async def connect(self):
"""Create connection pool asynchronously.
Expand All @@ -1056,38 +1067,6 @@ async def connect(self):
connect_timeout=self.timeout,
**self.connect_params)

async def close(self):
"""Terminate all pool connections.
"""
self.pool.terminate()
await self.pool.wait_closed()

async def cursor(self, conn=None, *args, **kwargs):
"""Get cursor for connection from pool.
"""
in_transaction = conn is not None
if not conn:
conn = await self.acquire()
try:
cursor = await conn.cursor(*args, **kwargs)
except:
if not in_transaction:
self.release(conn)
raise
cursor.release = functools.partial(
self.release_cursor, cursor,
in_transaction=in_transaction)
return cursor

async def release_cursor(self, cursor, in_transaction=False):
"""Release cursor coroutine. Unless in transaction,
the connection is also released back to the pool.
"""
conn = cursor.connection
await cursor.close()
if not in_transaction:
self.release(conn)


class MySQLDatabase(AsyncDatabase, peewee.MySQLDatabase):
"""MySQL database driver providing **single drop-in sync** connection
Expand Down

0 comments on commit 1c55d89

Please sign in to comment.