diff --git a/peewee_async/pool.py b/peewee_async/pool.py index e8586bb..e3988d4 100644 --- a/peewee_async/pool.py +++ b/peewee_async/pool.py @@ -2,7 +2,7 @@ import asyncio from typing import Any, Optional, cast -from .utils import aiopg, aiomysql, PoolProtocol, ConnectionProtocol, format_dsn, psycopg, psycopg_pool +from .utils import aiopg, aiomysql, ConnectionProtocol, format_dsn, psycopg, psycopg_pool class PoolBackend(metaclass=abc.ABCMeta): @@ -10,7 +10,7 @@ class PoolBackend(metaclass=abc.ABCMeta): """ def __init__(self, *, database: str, **kwargs: Any) -> None: - self.pool: Optional[PoolProtocol] = None + self.pool: Optional[Any] = None self.database = database self.connect_params = kwargs self._connection_lock = asyncio.Lock() @@ -23,11 +23,13 @@ def is_connected(self) -> bool: @property def min_size(self) -> int: - return self.pool.minsize + assert self.pool is not None, "Pool is not connected" + return cast(int, self.pool.minsize) @property def max_size(self) -> int: - return self.pool.maxsize + assert self.pool is not None, "Pool is not connected" + return cast(int, self.pool.maxsize) def has_acquired_connections(self) -> bool: if self.pool is not None: @@ -45,7 +47,7 @@ async def acquire(self) -> ConnectionProtocol: if self.pool is None: await self.connect() assert self.pool is not None, "Pool is not connected" - return await self.pool.acquire() + return cast(ConnectionProtocol, await self.pool.acquire()) async def release(self, conn: ConnectionProtocol) -> None: """Release connection to pool. @@ -76,12 +78,9 @@ async def create(self) -> None: """ if "connect_timeout" in self.connect_params: self.connect_params['timeout'] = self.connect_params.pop("connect_timeout") - self.pool = cast( - PoolProtocol, - await aiopg.create_pool( - database=self.database, - **self.connect_params - ) + self.pool = await aiopg.create_pool( + database=self.database, + **self.connect_params ) @@ -117,7 +116,7 @@ async def create(self) -> None: def has_acquired_connections(self) -> bool: if self.pool is not None: - return self.pool._nconns - self.pool._num_pool > 0 + return bool(self.pool.nconns - self.pool._num_pool > 0) return False async def acquire(self) -> ConnectionProtocol: @@ -126,7 +125,7 @@ async def acquire(self) -> ConnectionProtocol: if self.pool is None: await self.connect() assert self.pool is not None, "Pool is not connected" - return await self.pool.getconn() + return cast(ConnectionProtocol, await self.pool.getconn()) async def release(self, conn: ConnectionProtocol) -> None: """Release connection to pool. @@ -142,11 +141,13 @@ async def terminate(self) -> None: @property def min_size(self) -> int: - return self.pool.min_size + assert self.pool is not None, "Pool is not connected" + return cast(int, self.pool.min_size) @property def max_size(self) -> int: - return self.pool.max_size + assert self.pool is not None, "Pool is not connected" + return cast(int, self.pool.max_size) class MysqlPoolBackend(PoolBackend): @@ -156,9 +157,6 @@ class MysqlPoolBackend(PoolBackend): async def create(self) -> None: """Create connection pool asynchronously. """ - self.pool = cast( - PoolProtocol, - await aiomysql.create_pool( - db=self.database, **self.connect_params - ), + self.pool = await aiomysql.create_pool( + db=self.database, **self.connect_params ) diff --git a/peewee_async/utils.py b/peewee_async/utils.py index bcf3d87..1d17da1 100644 --- a/peewee_async/utils.py +++ b/peewee_async/utils.py @@ -57,27 +57,6 @@ def cursor( ... -class PoolProtocol(Protocol): - - _used: Set[ConnectionProtocol] - - @property - def closed(self) -> bool: - ... - - async def acquire(self) -> ConnectionProtocol: - ... - - def release(self, conn: ConnectionProtocol) -> None: - ... - - def terminate(self) -> None: - ... - - async def wait_closed(self) -> None: - ... - - FetchResults = Callable[[CursorProtocol], Awaitable[Any]] diff --git a/tests/test_database.py b/tests/test_database.py index cd862bd..aa2250d 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -86,8 +86,8 @@ async def test_connections_param(db_name: str) -> None: database = db_cls(**default_params) await database.aio_connect() - assert database.pool_backend.min_size == 2 # type: ignore - assert database.pool_backend.max_size == 3 # type: ignore + assert database.pool_backend.min_size == 2 + assert database.pool_backend.max_size == 3 await database.aio_close()