Skip to content

Commit

Permalink
rename models, change prefixes from async to aio, add abstract method…
Browse files Browse the repository at this point in the history
… for close cursor
  • Loading branch information
akerlay committed Apr 8, 2024
1 parent 1c55d89 commit 05d9ca5
Showing 1 changed file with 21 additions and 19 deletions.
40 changes: 21 additions & 19 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ async def connect_async(self, loop=None, timeout=None):
timeout=timeout,
**self.connect_params_async
)
await conn.connect()
await conn.create()
self._async_conn = conn

async def cursor_async(self):
Expand All @@ -735,7 +735,7 @@ async def close_async(self):
if self._async_conn:
conn = self._async_conn
self._async_conn = None
await conn.close()
await conn.terminate()

async def push_transaction_async(self):
"""Increment async transaction depth.
Expand Down Expand Up @@ -852,7 +852,7 @@ async def aio_execute(self, query):
return (await coroutine(query))


class AsyncConnectionPool(metaclass=abc.ABCMeta):
class AioPool(metaclass=abc.ABCMeta):
"""Asynchronous database connection pool.
"""
def __init__(self, *, database=None, loop=None, timeout=None, **kwargs):
Expand All @@ -873,12 +873,12 @@ def release(self, conn):
self.pool.release(conn)

@abc.abstractmethod
async def connect(self):
async def create(self):
"""Create connection pool asynchronously.
"""
raise NotImplementedError

async def close(self):
async def terminate(self):
"""Terminate all pool connections.
"""
self.pool.terminate()
Expand Down Expand Up @@ -906,17 +906,22 @@ async def release_cursor(self, cursor, in_transaction=False):
the connection is also released back to the pool.
"""
conn = cursor.connection
await cursor.close()
await self.close_cursor(cursor)
if not in_transaction:
self.release(conn)

@abc.abstractmethod
async def close_cursor(self, cursor):
raise NotImplementedError



##############
# PostgreSQL #
##############


class AsyncPostgresqlConnection(AsyncConnectionPool):
class AioPostgresqlPool(AioPool):
"""Asynchronous database connection pool.
"""
def __init__(self, *, database=None, loop=None, timeout=None, **kwargs):
Expand All @@ -927,7 +932,7 @@ def __init__(self, *, database=None, loop=None, timeout=None, **kwargs):
**kwargs,
)

async def connect(self):
async def create(self):
"""Create connection pool asynchronously.
"""
self.pool = await aiopg.create_pool(
Expand All @@ -936,14 +941,8 @@ async def connect(self):
database=self.database,
**self.connect_params)

async def release_cursor(self, cursor, in_transaction=False):
"""Release cursor coroutine. Unless in transaction,
the connection is also released back to the pool.
"""
conn = cursor.connection
async def close_cursor(self, cursor):
cursor.close()
if not in_transaction:
self.release(conn)


class AsyncPostgresqlMixin(AsyncDatabase):
Expand All @@ -953,7 +952,7 @@ class AsyncPostgresqlMixin(AsyncDatabase):
if psycopg2:
Error = psycopg2.Error

def init_async(self, conn_cls=AsyncPostgresqlConnection,
def init_async(self, conn_cls=AioPostgresqlPool,
enable_json=False, enable_hstore=False):
if not aiopg:
raise Exception("Error, aiopg is not installed!")
Expand Down Expand Up @@ -1054,11 +1053,11 @@ def use_speedups(self, value):
#########


class AsyncMySQLConnection(AsyncConnectionPool):
class AioMysqlPool(AioPool):
"""Asynchronous database connection pool.
"""

async def connect(self):
async def create(self):
"""Create connection pool asynchronously.
"""
self.pool = await aiomysql.create_pool(
Expand All @@ -1067,6 +1066,9 @@ async def connect(self):
connect_timeout=self.timeout,
**self.connect_params)

async def close_cursor(self, cursor):
await cursor.close()


class MySQLDatabase(AsyncDatabase, peewee.MySQLDatabase):
"""MySQL database driver providing **single drop-in sync** connection
Expand All @@ -1087,7 +1089,7 @@ def init(self, database, **kwargs):
raise Exception("Error, aiomysql is not installed!")
self.min_connections = 1
self.max_connections = 1
self._async_conn_cls = kwargs.pop('async_conn', AsyncMySQLConnection)
self._async_conn_cls = kwargs.pop('async_conn', AioMysqlPool)
super().init(database, **kwargs)

@property
Expand Down

0 comments on commit 05d9ca5

Please sign in to comment.