Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Refactor Async(Postgresql/MySQL)Connection #213

Merged
merged 2 commits into from
Apr 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 50 additions & 69 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Copyright (c) 2014, Alexey Kinëv <rudy@05bit.com>

"""
import abc
import asyncio
import contextlib
import functools
Expand Down Expand Up @@ -712,7 +713,7 @@ async def connect_async(self, loop=None, timeout=None):
timeout=timeout,
**self.connect_params_async
)
await conn.connect()
await conn.create()
self._async_conn = conn

async def cursor_async(self):
Expand All @@ -734,7 +735,7 @@ async def close_async(self):
if self._async_conn:
conn = self._async_conn
self._async_conn = None
await conn.close()
await conn.terminate()

async def push_transaction_async(self):
"""Increment async transaction depth.
Expand Down Expand Up @@ -851,19 +852,14 @@ async def aio_execute(self, query):
return (await coroutine(query))


##############
# PostgreSQL #
##############


class AsyncPostgresqlConnection:
class AioPool(metaclass=abc.ABCMeta):
"""Asynchronous database connection pool.
"""
def __init__(self, *, database=None, loop=None, timeout=None, **kwargs):
self.pool = None
self.loop = loop
self.database = database
self.timeout = timeout or aiopg.DEFAULT_TIMEOUT
self.timeout = timeout
self.connect_params = kwargs

async def acquire(self):
Expand All @@ -876,24 +872,20 @@ def release(self, conn):
"""
self.pool.release(conn)

async def connect(self):
@abc.abstractmethod
async def create(self):
"""Create connection pool asynchronously.
"""
self.pool = await aiopg.create_pool(
loop=self.loop,
timeout=self.timeout,
database=self.database,
**self.connect_params)
raise NotImplementedError

async def close(self):
async def terminate(self):
"""Terminate all pool connections.
"""
self.pool.terminate()
await self.pool.wait_closed()

async def cursor(self, conn=None, *args, **kwargs):
"""Get a cursor for the specified transaction connection
or acquire from the pool.
"""Get cursor for connection from pool.
"""
in_transaction = conn is not None
if not conn:
Expand All @@ -914,10 +906,44 @@ async def release_cursor(self, cursor, in_transaction=False):
the connection is also released back to the pool.
"""
conn = cursor.connection
cursor.close()
await self.close_cursor(cursor)
if not in_transaction:
self.release(conn)

@abc.abstractmethod
async def close_cursor(self, cursor):
raise NotImplementedError



##############
# PostgreSQL #
##############


class AioPostgresqlPool(AioPool):
"""Asynchronous database connection pool.
"""
def __init__(self, *, database=None, loop=None, timeout=None, **kwargs):
super().__init__(
database=database,
loop=loop,
timeout=timeout or aiopg.DEFAULT_TIMEOUT,
**kwargs,
)

async def create(self):
"""Create connection pool asynchronously.
"""
self.pool = await aiopg.create_pool(
loop=self.loop,
timeout=self.timeout,
database=self.database,
**self.connect_params)

async def close_cursor(self, cursor):
cursor.close()


class AsyncPostgresqlMixin(AsyncDatabase):
"""Mixin for `peewee.PostgresqlDatabase` providing extra methods
Expand All @@ -926,7 +952,7 @@ class AsyncPostgresqlMixin(AsyncDatabase):
if psycopg2:
Error = psycopg2.Error

def init_async(self, conn_cls=AsyncPostgresqlConnection,
def init_async(self, conn_cls=AioPostgresqlPool,
enable_json=False, enable_hstore=False):
if not aiopg:
raise Exception("Error, aiopg is not installed!")
Expand Down Expand Up @@ -1027,27 +1053,11 @@ def use_speedups(self, value):
#########


class AsyncMySQLConnection:
class AioMysqlPool(AioPool):
"""Asynchronous database connection pool.
"""
def __init__(self, *, database=None, loop=None, timeout=None, **kwargs):
self.pool = None
self.loop = loop
self.database = database
self.timeout = timeout
self.connect_params = kwargs

async def acquire(self):
"""Acquire connection from pool.
"""
return (await self.pool.acquire())

def release(self, conn):
"""Release connection to pool.
"""
self.pool.release(conn)

async def connect(self):
async def create(self):
"""Create connection pool asynchronously.
"""
self.pool = await aiomysql.create_pool(
Expand All @@ -1056,37 +1066,8 @@ async def connect(self):
connect_timeout=self.timeout,
**self.connect_params)

async def close(self):
"""Terminate all pool connections.
"""
self.pool.terminate()
await self.pool.wait_closed()

async def cursor(self, conn=None, *args, **kwargs):
"""Get cursor for connection from pool.
"""
in_transaction = conn is not None
if not conn:
conn = await self.acquire()
try:
cursor = await conn.cursor(*args, **kwargs)
except:
if not in_transaction:
self.release(conn)
raise
cursor.release = functools.partial(
self.release_cursor, cursor,
in_transaction=in_transaction)
return cursor

async def release_cursor(self, cursor, in_transaction=False):
"""Release cursor coroutine. Unless in transaction,
the connection is also released back to the pool.
"""
conn = cursor.connection
async def close_cursor(self, cursor):
await cursor.close()
if not in_transaction:
self.release(conn)


class MySQLDatabase(AsyncDatabase, peewee.MySQLDatabase):
Expand All @@ -1108,7 +1089,7 @@ def init(self, database, **kwargs):
raise Exception("Error, aiomysql is not installed!")
self.min_connections = 1
self.max_connections = 1
self._async_conn_cls = kwargs.pop('async_conn', AsyncMySQLConnection)
self._async_conn_cls = kwargs.pop('async_conn', AioMysqlPool)
super().init(database, **kwargs)

@property
Expand Down
Loading