From f753febb01a10573471be552eecdbdf397f007cc Mon Sep 17 00:00:00 2001 From: kalombo Date: Sat, 22 Jun 2024 18:44:31 +0500 Subject: [PATCH] feat: add aio_prefetch --- peewee_async.py | 37 ++++++++++++++++++++++++++++++- peewee_async_compat.py | 4 ++++ tests/aio_model/test_shortcuts.py | 37 ++++++++++++++++++++++++++++++- tests/test_shortcuts.py | 3 ++- 4 files changed, 78 insertions(+), 3 deletions(-) diff --git a/peewee_async.py b/peewee_async.py index eb85799..28f6977 100644 --- a/peewee_async.py +++ b/peewee_async.py @@ -59,6 +59,7 @@ 'PooledMySQLDatabase', 'Transaction', 'AioModel', + 'aio_prefetch' # Compatibility API (deprecated in v1.0 release) 'Manager', @@ -655,6 +656,38 @@ def init(self, database, **kwargs): register_database(PooledMySQLDatabase, 'mysql+pool+async') +async def aio_prefetch(sq, *subqueries, prefetch_type): + """Asynchronous version of the `prefetch()` from peewee.""" + if not subqueries: + return sq + + fixed_queries = peewee.prefetch_add_subquery(sq, subqueries, prefetch_type) + deps = {} + rel_map = {} + + for pq in reversed(fixed_queries): + query_model = pq.model + if pq.fields: + for rel_model in pq.rel_models: + rel_map.setdefault(rel_model, []) + rel_map[rel_model].append(pq) + + deps[query_model] = {} + id_map = deps[query_model] + has_relations = bool(rel_map.get(query_model)) + + result = await pq.query.aio_execute() + + for instance in result: + if pq.fields: + pq.store_instance(instance, id_map) + if has_relations: + for rel in rel_map[query_model]: + rel.populate_instance(instance, deps[rel.model]) + + return result + + class AioQueryMixin: @peewee.database_required async def aio_execute(self, database): @@ -746,6 +779,9 @@ async def aio_exists(self, database): clone._offset = None return bool(await clone.aio_scalar()) + def aio_prefetch(self, *subqueries, **kwargs): + return aio_prefetch(self, *subqueries, **kwargs) + class AioSelect(peewee.Select, AioSelectMixin): pass @@ -810,7 +846,6 @@ async def aio_delete_instance(self, recursive=False, delete_nullable=False): if recursive: dependencies = self.dependencies(delete_nullable) for query, fk in reversed(list(dependencies)): - print(query, fk) model = fk.model if fk.null and not delete_nullable: await model.update(**{fk.name: None}).where(query).aio_execute() diff --git a/peewee_async_compat.py b/peewee_async_compat.py index d7776ac..9171ea9 100644 --- a/peewee_async_compat.py +++ b/peewee_async_compat.py @@ -106,6 +106,10 @@ async def count(query, clear_limit=False): async def prefetch(sq, *subqueries, prefetch_type): """Asynchronous version of the `prefetch()` from peewee.""" + warnings.warn( + "`prefetch` method is deprecated, use `AioModel.aio_prefetch` or aio_prefetch instead.", + DeprecationWarning + ) database = _query_db(sq) if not subqueries: result = await database.aio_execute(sq) diff --git a/tests/aio_model/test_shortcuts.py b/tests/aio_model/test_shortcuts.py index 0a4bbab..f022e74 100644 --- a/tests/aio_model/test_shortcuts.py +++ b/tests/aio_model/test_shortcuts.py @@ -5,7 +5,7 @@ from peewee import fn from tests.conftest import dbs_all -from tests.models import TestModel, IntegerTestModel, TestModelAlpha, TestModelBeta +from tests.models import TestModel, IntegerTestModel, TestModelAlpha, TestModelBeta, TestModelGamma @dbs_all @@ -133,3 +133,38 @@ async def test_aio_exists(db): assert await TestModel.select().where(TestModel.data=="data").aio_exists() is True assert await TestModel.select().where(TestModel.data == "not_existed").aio_exists() is False + + +@dbs_all +@pytest.mark.parametrize( + "prefetch_type", + peewee.PREFETCH_TYPE.values() +) +async def test_aio_prefetch(db, prefetch_type): + alpha_1 = await TestModelAlpha.aio_create(text='Alpha 1') + alpha_2 = await TestModelAlpha.aio_create(text='Alpha 2') + + beta_11 = await TestModelBeta.aio_create(alpha=alpha_1, text='Beta 11') + beta_12 = await TestModelBeta.aio_create(alpha=alpha_1, text='Beta 12') + _ = await TestModelBeta.aio_create( + alpha=alpha_2, text='Beta 21' + ) + _ = await TestModelBeta.aio_create( + alpha=alpha_2, text='Beta 22' + ) + + gamma_111 = await TestModelGamma.aio_create( + beta=beta_11, text='Gamma 111' + ) + gamma_112 = await TestModelGamma.aio_create( + beta=beta_11, text='Gamma 112' + ) + + result = await TestModelAlpha.select().order_by(TestModelAlpha.id).aio_prefetch( + TestModelBeta.select().order_by(TestModelBeta.id), + TestModelGamma.select().order_by(TestModelGamma.id), + prefetch_type=prefetch_type, + ) + assert tuple(result) == (alpha_1, alpha_2) + assert tuple(result[0].betas) == (beta_11, beta_12) + assert tuple(result[0].betas[0].gammas) == (gamma_111, gamma_112) diff --git a/tests/test_shortcuts.py b/tests/test_shortcuts.py index 48ff8a7..736b58e 100644 --- a/tests/test_shortcuts.py +++ b/tests/test_shortcuts.py @@ -3,6 +3,7 @@ import peewee import pytest +import peewee_async from tests.conftest import manager_for_all_dbs from tests.models import TestModel, TestModelAlpha, TestModelBeta, TestModelGamma @@ -32,7 +33,7 @@ async def test_prefetch(manager, prefetch_type): gamma_112 = await manager.create( TestModelGamma, beta=beta_11, text='Gamma 112') - result = await manager.prefetch( + result = await peewee_async.prefetch( TestModelAlpha.select().order_by(TestModelAlpha.id), TestModelBeta.select().order_by(TestModelBeta.id), TestModelGamma.select().order_by(TestModelGamma.id),