Skip to content

Commit

Permalink
feat: add aio_prefetch
Browse files Browse the repository at this point in the history
  • Loading branch information
kalombos committed Jun 22, 2024
1 parent 9b8a59b commit f7d977d
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 3 deletions.
37 changes: 36 additions & 1 deletion peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
'PooledMySQLDatabase',
'Transaction',
'AioModel',
'aio_prefetch'

# Compatibility API (deprecated in v1.0 release)
'Manager',
Expand Down Expand Up @@ -655,6 +656,38 @@ def init(self, database, **kwargs):
register_database(PooledMySQLDatabase, 'mysql+pool+async')


async def aio_prefetch(sq, *subqueries, prefetch_type):
"""Asynchronous version of the `prefetch()` from peewee."""
if not subqueries:
return sq

fixed_queries = peewee.prefetch_add_subquery(sq, subqueries, prefetch_type)
deps = {}
rel_map = {}

for pq in reversed(fixed_queries):
query_model = pq.model
if pq.fields:
for rel_model in pq.rel_models:
rel_map.setdefault(rel_model, [])
rel_map[rel_model].append(pq)

deps[query_model] = {}
id_map = deps[query_model]
has_relations = bool(rel_map.get(query_model))

result = await pq.query.aio_execute()

for instance in result:
if pq.fields:
pq.store_instance(instance, id_map)
if has_relations:
for rel in rel_map[query_model]:
rel.populate_instance(instance, deps[rel.model])

return result


class AioQueryMixin:
@peewee.database_required
async def aio_execute(self, database):
Expand Down Expand Up @@ -746,6 +779,9 @@ async def aio_exists(self, database):
clone._offset = None
return bool(await clone.aio_scalar())

def aio_prefetch(self, *subqueries, **kwargs):
return aio_prefetch(self, *subqueries, **kwargs)


class AioSelect(peewee.Select, AioSelectMixin):
pass
Expand Down Expand Up @@ -810,7 +846,6 @@ async def aio_delete_instance(self, recursive=False, delete_nullable=False):
if recursive:
dependencies = self.dependencies(delete_nullable)
for query, fk in reversed(list(dependencies)):
print(query, fk)
model = fk.model
if fk.null and not delete_nullable:
await model.update(**{fk.name: None}).where(query).aio_execute()
Expand Down
4 changes: 4 additions & 0 deletions peewee_async_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ async def count(query, clear_limit=False):

async def prefetch(sq, *subqueries, prefetch_type):
"""Asynchronous version of the `prefetch()` from peewee."""
warnings.warn(
"`prefetch` method is deprecated, use `AioModel.aio_prefetch` or aio_prefetch instead.",
DeprecationWarning
)
database = _query_db(sq)
if not subqueries:
result = await database.aio_execute(sq)
Expand Down
37 changes: 36 additions & 1 deletion tests/aio_model/test_shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from peewee import fn

from tests.conftest import dbs_all
from tests.models import TestModel, IntegerTestModel, TestModelAlpha, TestModelBeta
from tests.models import TestModel, IntegerTestModel, TestModelAlpha, TestModelBeta, TestModelGamma


@dbs_all
Expand Down Expand Up @@ -133,3 +133,38 @@ async def test_aio_exists(db):

assert await TestModel.select().where(TestModel.data=="data").aio_exists() is True
assert await TestModel.select().where(TestModel.data == "not_existed").aio_exists() is False


@dbs_all
@pytest.mark.parametrize(
"prefetch_type",
peewee.PREFETCH_TYPE.values()
)
async def test_aio_prefetch(db, prefetch_type):
alpha_1 = await TestModelAlpha.aio_create(text='Alpha 1')
alpha_2 = await TestModelAlpha.aio_create(text='Alpha 2')

beta_11 = await TestModelBeta.aio_create(alpha=alpha_1, text='Beta 11')
beta_12 = await TestModelBeta.aio_create(alpha=alpha_1, text='Beta 12')
_ = await TestModelBeta.aio_create(
alpha=alpha_2, text='Beta 21'
)
_ = await TestModelBeta.aio_create(
alpha=alpha_2, text='Beta 22'
)

gamma_111 = await TestModelGamma.aio_create(
beta=beta_11, text='Gamma 111'
)
gamma_112 = await TestModelGamma.aio_create(
beta=beta_11, text='Gamma 112'
)

result = await TestModelAlpha.select().order_by(TestModelAlpha.id).aio_prefetch(
TestModelBeta.select().order_by(TestModelBeta.id),
TestModelGamma.select().order_by(TestModelGamma.id),
prefetch_type=prefetch_type,
)
assert tuple(result) == (alpha_1, alpha_2)
assert tuple(result[0].betas) == (beta_11, beta_12)
assert tuple(result[0].betas[0].gammas) == (gamma_111, gamma_112)
3 changes: 2 additions & 1 deletion tests/test_shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import peewee
import pytest

import peewee_async
from tests.conftest import manager_for_all_dbs
from tests.models import TestModel, TestModelAlpha, TestModelBeta, TestModelGamma

Expand Down Expand Up @@ -32,7 +33,7 @@ async def test_prefetch(manager, prefetch_type):
gamma_112 = await manager.create(
TestModelGamma, beta=beta_11, text='Gamma 112')

result = await manager.prefetch(
result = await peewee_async.prefetch(
TestModelAlpha.select().order_by(TestModelAlpha.id),
TestModelBeta.select().order_by(TestModelBeta.id),
TestModelGamma.select().order_by(TestModelGamma.id),
Expand Down

0 comments on commit f7d977d

Please sign in to comment.