Skip to content

Commit

Permalink
fix typing
Browse files Browse the repository at this point in the history
  • Loading branch information
kalombos committed Oct 11, 2024
1 parent 4e3ca11 commit cc08a31
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 43 deletions.
38 changes: 18 additions & 20 deletions peewee_async/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
import asyncio
from typing import Any, Optional, cast

from .utils import aiopg, aiomysql, PoolProtocol, ConnectionProtocol, format_dsn, psycopg, psycopg_pool
from .utils import aiopg, aiomysql, ConnectionProtocol, format_dsn, psycopg, psycopg_pool


class PoolBackend(metaclass=abc.ABCMeta):
"""Asynchronous database connection pool.
"""

def __init__(self, *, database: str, **kwargs: Any) -> None:
self.pool: Optional[PoolProtocol] = None
self.pool: Optional[Any] = None
self.database = database
self.connect_params = kwargs
self._connection_lock = asyncio.Lock()
Expand All @@ -23,11 +23,13 @@ def is_connected(self) -> bool:

@property
def min_size(self) -> int:
return self.pool.minsize
assert self.pool is not None, "Pool is not connected"
return cast(int, self.pool.minsize)

@property
def max_size(self) -> int:
return self.pool.maxsize
assert self.pool is not None, "Pool is not connected"
return cast(int, self.pool.maxsize)

def has_acquired_connections(self) -> bool:
if self.pool is not None:
Expand All @@ -45,7 +47,7 @@ async def acquire(self) -> ConnectionProtocol:
if self.pool is None:
await self.connect()
assert self.pool is not None, "Pool is not connected"
return await self.pool.acquire()
return cast(ConnectionProtocol, await self.pool.acquire())

async def release(self, conn: ConnectionProtocol) -> None:
"""Release connection to pool.
Expand Down Expand Up @@ -76,12 +78,9 @@ async def create(self) -> None:
"""
if "connect_timeout" in self.connect_params:
self.connect_params['timeout'] = self.connect_params.pop("connect_timeout")
self.pool = cast(
PoolProtocol,
await aiopg.create_pool(
database=self.database,
**self.connect_params
)
self.pool = await aiopg.create_pool(
database=self.database,
**self.connect_params
)


Expand Down Expand Up @@ -117,7 +116,7 @@ async def create(self) -> None:

def has_acquired_connections(self) -> bool:
if self.pool is not None:
return self.pool._nconns - self.pool._num_pool > 0
return bool(self.pool.nconns - self.pool._num_pool > 0)
return False

async def acquire(self) -> ConnectionProtocol:
Expand All @@ -126,7 +125,7 @@ async def acquire(self) -> ConnectionProtocol:
if self.pool is None:
await self.connect()
assert self.pool is not None, "Pool is not connected"
return await self.pool.getconn()
return cast(ConnectionProtocol, await self.pool.getconn())

async def release(self, conn: ConnectionProtocol) -> None:
"""Release connection to pool.
Expand All @@ -142,11 +141,13 @@ async def terminate(self) -> None:

@property
def min_size(self) -> int:
return self.pool.min_size
assert self.pool is not None, "Pool is not connected"
return cast(int, self.pool.min_size)

@property
def max_size(self) -> int:
return self.pool.max_size
assert self.pool is not None, "Pool is not connected"
return cast(int, self.pool.max_size)


class MysqlPoolBackend(PoolBackend):
Expand All @@ -156,9 +157,6 @@ class MysqlPoolBackend(PoolBackend):
async def create(self) -> None:
"""Create connection pool asynchronously.
"""
self.pool = cast(
PoolProtocol,
await aiomysql.create_pool(
db=self.database, **self.connect_params
),
self.pool = await aiomysql.create_pool(
db=self.database, **self.connect_params
)
21 changes: 0 additions & 21 deletions peewee_async/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,27 +57,6 @@ def cursor(
...


class PoolProtocol(Protocol):

_used: Set[ConnectionProtocol]

@property
def closed(self) -> bool:
...

async def acquire(self) -> ConnectionProtocol:
...

def release(self, conn: ConnectionProtocol) -> None:
...

def terminate(self) -> None:
...

async def wait_closed(self) -> None:
...


FetchResults = Callable[[CursorProtocol], Awaitable[Any]]


Expand Down
4 changes: 2 additions & 2 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ async def test_connections_param(db_name: str) -> None:
database = db_cls(**default_params)
await database.aio_connect()

assert database.pool_backend.min_size == 2 # type: ignore
assert database.pool_backend.max_size == 3 # type: ignore
assert database.pool_backend.min_size == 2
assert database.pool_backend.max_size == 3

await database.aio_close()

Expand Down

0 comments on commit cc08a31

Please sign in to comment.