Skip to content

Commit

Permalink
added AioModel feature
Browse files Browse the repository at this point in the history
  • Loading branch information
kalombos committed Dec 29, 2023
1 parent dda6949 commit 11e8d76
Show file tree
Hide file tree
Showing 7 changed files with 304 additions and 6 deletions.
96 changes: 96 additions & 0 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -1374,3 +1374,99 @@ def del_data(self, task):
"""Delete data for task from stored data dict.
"""
del self.data[id(task)]


class AioQueryMixin:
@peewee.database_required
async def aio_execute(self, database):
return await execute(self)


class AioModelDelete(peewee.ModelDelete, AioQueryMixin):
pass


class AioModelUpdate(peewee.ModelUpdate, AioQueryMixin):
pass


class AioModelInsert(peewee.ModelInsert, AioQueryMixin):
pass


class AioModelSelect(peewee.ModelSelect, AioQueryMixin):

async def aio_get(self, database=None):
clone = self.paginate(1, 1)
try:
return (await clone.aio_execute(database))[0]
except IndexError:
sql, params = clone.sql()
raise self.model.DoesNotExist('%s instance matching query does '
'not exist:\nSQL: %s\nParams: %s' %
(clone.model, sql, params))


class AioModel(peewee.Model):
"""
Implementation of most methods is copied from sync versions with replacement to async calls
"""

@classmethod
def select(cls, *fields):
is_default = not fields
if not fields:
fields = cls._meta.sorted_fields
return AioModelSelect(cls, fields, is_default=is_default)

@classmethod
def update(cls, __data=None, **update):
return AioModelUpdate(cls, cls._normalize_data(__data, update))

@classmethod
def insert(cls, __data=None, **insert):
return AioModelInsert(cls, cls._normalize_data(__data, insert))

@classmethod
def insert_many(cls, rows, fields=None):
return AioModelInsert(cls, insert=rows, columns=fields)

@classmethod
def insert_from(cls, query, fields):
columns = [getattr(cls, field) if isinstance(field, str)
else field for field in fields]
return AioModelInsert(cls, insert=query, columns=columns)

@classmethod
def delete(cls):
return AioModelDelete(cls)

@classmethod
async def aio_get(cls, *query, **filters):
sq = cls.select()
if query:
if len(query) == 1 and isinstance(query[0], int):
sq = sq.where(cls._meta.primary_key == query[0])
else:
sq = sq.where(*query)
if filters:
sq = sq.filter(**filters)
return await sq.aio_get()

@classmethod
async def aio_get_or_none(cls, *query, **filters):
try:
return await cls.aio_get(*query, **filters)
except cls.DoesNotExist:
return None

@classmethod
async def aio_create(cls, **data):
"""
the implementation is different from sync create method
"""
inst = cls(**data)
pk = await cls.insert(**dict(inst.__data__)).aio_execute()
if inst._pk is None:
inst._pk = pk
return inst
Empty file added tests/aio_model/__init__.py
Empty file.
47 changes: 47 additions & 0 deletions tests/aio_model/test_deleting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import uuid

from tests.conftest import postgres_only, all_dbs
from tests.models import TestModel
from tests.utils import model_has_fields


@all_dbs
async def test_delete__count(manager):
query = TestModel.insert_many([
{'text': "Test %s" % uuid.uuid4()},
{'text': "Test %s" % uuid.uuid4()},
])
await query.aio_execute()

count = await TestModel.delete().aio_execute()

assert count == 2


@all_dbs
async def test_delete__by_condition(manager):
expected_text = "text1"
deleted_text = "text2"
query = TestModel.insert_many([
{'text': expected_text},
{'text': deleted_text},
])
await query.aio_execute()

await TestModel.delete().where(TestModel.text == deleted_text).aio_execute()

res = await TestModel.select().aio_execute()
assert len(res) == 1
assert res[0].text == expected_text


@postgres_only
async def test_delete__return_model(manager):
m = await TestModel.aio_create(text="text", data="data")

res = await TestModel.delete().returning(TestModel).aio_execute()
assert model_has_fields(res[0], {
"id": m.id,
"text": m.text,
"data": m.data
}) is True
100 changes: 100 additions & 0 deletions tests/aio_model/test_inserting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import uuid

from tests.conftest import postgres_only, all_dbs
from tests.models import TestModel, UUIDTestModel
from tests.utils import model_has_fields


@all_dbs
async def test_insert_many(manager):
last_id = await TestModel.insert_many([
{'text': "Test %s" % uuid.uuid4()},
{'text': "Test %s" % uuid.uuid4()},
]).aio_execute()

res = await TestModel.select().aio_execute()

assert len(res) == 2
assert last_id in [m.id for m in res]


@all_dbs
async def test_insert__return_id(manager):
last_id = await TestModel.insert(text="Test %s" % uuid.uuid4()).aio_execute()

res = await TestModel.select().aio_execute()
obj = res[0]
assert last_id == obj.id


@postgres_only
async def test_insert_on_conflict_ignore__last_id_is_none(manager):
query = TestModel.insert(text="text").on_conflict_ignore()
await query.aio_execute()

last_id = await query.aio_execute()

assert last_id is None


@postgres_only
async def test_insert_on_conflict_ignore__return_model(manager):
query = TestModel.insert(text="text", data="data").on_conflict_ignore().returning(TestModel)

res = await query.aio_execute()

inserted = res[0]
res = await TestModel.select().aio_execute()
expected = res[0]

assert model_has_fields(inserted, {
"id": expected.id,
"text": expected.text,
"data": expected.data
}) is True


@postgres_only
async def test_insert_on_conflict_ignore__inserted_once(manager):
query = TestModel.insert(text="text").on_conflict_ignore()
last_id = await query.aio_execute()

await query.aio_execute()

res = await TestModel.select().aio_execute()
assert len(res) == 1
assert res[0].id == last_id


@postgres_only
async def test_insert__uuid_pk(manager):
query = UUIDTestModel.insert(text="Test %s" % uuid.uuid4())
last_id = await query.aio_execute()
assert len(str(last_id)) == 36


@postgres_only
async def test_insert__return_model(manager):
text = "Test %s" % uuid.uuid4()
data = "data"
query = TestModel.insert(text=text, data=data).returning(TestModel)

res = await query.aio_execute()

inserted = res[0]
assert model_has_fields(
inserted, {"id": inserted.id, "text": text, "data": data}
) is True


@postgres_only
async def test_insert_many__return_model(manager):
texts = [f"text{n}" for n in range(2)]
query = TestModel.insert_many([
{"text": text} for text in texts
]).returning(TestModel)

res = await query.aio_execute()

texts = [m.text for m in res]
assert sorted(texts) == ["text0", "text1"]
20 changes: 20 additions & 0 deletions tests/aio_model/test_selecting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import uuid

import pytest

from tests.conftest import all_dbs, postgres_only
from tests.models import TestModel, TestModelAlpha, TestModelBeta


@all_dbs
async def test__select__w_join(manager):
alpha = await TestModelAlpha.aio_create(text="Test 1")
beta = await TestModelBeta.aio_create(alpha_id=alpha.id, text="text")

result = (await TestModelBeta.select(TestModelBeta, TestModelAlpha).join(
TestModelAlpha,
attr="joined_alpha",
).aio_execute())[0]

assert result.id == beta.id
assert result.joined_alpha.id == alpha.id
33 changes: 33 additions & 0 deletions tests/aio_model/test_shortcuts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import uuid

import pytest

from tests.conftest import all_dbs, postgres_only
from tests.models import TestModel, TestModelAlpha, TestModelBeta



@all_dbs
async def test_aio_get(manager):
obj1 = await TestModel.aio_create(text="Test 1")
obj2 = await TestModel.aio_create(text="Test 2")

result = await TestModel.aio_get(TestModel.id == obj1.id)
assert result.id == obj1.id

result = await TestModel.aio_get(TestModel.text == "Test 2")
assert result.id == obj2.id

with pytest.raises(TestModel.DoesNotExist):
await TestModel.aio_get(TestModel.text == "unknown")


@all_dbs
async def test_aio_get_or_none(manager):
obj1 = await TestModel.aio_create(text="Test 1")

result = await TestModel.aio_get_or_none(TestModel.id == obj1.id)
assert result.id == obj1.id

result = await TestModel.aio_get_or_none(TestModel.text == "unknown")
assert result is None
14 changes: 8 additions & 6 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import peewee

from peewee_async import AioModel

class TestModel(peewee.Model):

class TestModel(AioModel):
__test__ = False # disable pytest warnings
text = peewee.CharField(max_length=100, unique=True)
data = peewee.TextField(default='')
Expand All @@ -12,15 +14,15 @@ def __str__(self):
return '<%s id=%s> %s' % (self.__class__.__name__, self.id, self.text)


class TestModelAlpha(peewee.Model):
class TestModelAlpha(AioModel):
__test__ = False
text = peewee.CharField()

def __str__(self):
return '<%s id=%s> %s' % (self.__class__.__name__, self.id, self.text)


class TestModelBeta(peewee.Model):
class TestModelBeta(AioModel):
__test__ = False
alpha = peewee.ForeignKeyField(TestModelAlpha, backref='betas')
text = peewee.CharField()
Expand All @@ -29,7 +31,7 @@ def __str__(self):
return '<%s id=%s> %s' % (self.__class__.__name__, self.id, self.text)


class TestModelGamma(peewee.Model):
class TestModelGamma(AioModel):
__test__ = False
text = peewee.CharField()
beta = peewee.ForeignKeyField(TestModelBeta, backref='gammas')
Expand All @@ -38,15 +40,15 @@ def __str__(self):
return '<%s id=%s> %s' % (self.__class__.__name__, self.id, self.text)


class UUIDTestModel(peewee.Model):
class UUIDTestModel(AioModel):
id = peewee.UUIDField(primary_key=True, default=uuid.uuid4)
text = peewee.CharField()

def __str__(self):
return '<%s id=%s> %s' % (self.__class__.__name__, self.id, self.text)


class CompositeTestModel(peewee.Model):
class CompositeTestModel(AioModel):
"""A simple "through" table for many-to-many relationship."""
uuid = peewee.ForeignKeyField(UUIDTestModel)
alpha = peewee.ForeignKeyField(TestModelAlpha)
Expand Down

0 comments on commit 11e8d76

Please sign in to comment.