Skip to content

Commit

Permalink
speedup and rewrite obsolette tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kalombos committed Dec 28, 2023
1 parent bd2ed75 commit dda6949
Show file tree
Hide file tree
Showing 8 changed files with 375 additions and 512 deletions.
16 changes: 9 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import asyncio

import pytest
from peewee import sort_models

import peewee_async
from tests.db_config import DB_CLASSES, DB_DEFAULTS
from tests.models import TestModel, UUIDTestModel, TestModelAlpha, TestModelBeta, TestModelGamma, CompositeTestModel
from tests.models import ALL_MODELS

try:
import aiopg
Expand Down Expand Up @@ -34,19 +35,20 @@ async def manager(request):

params = DB_DEFAULTS[db]
database = DB_CLASSES[db](**params)
models = [TestModel, UUIDTestModel, TestModelAlpha,
TestModelBeta, TestModelGamma, CompositeTestModel]
database._allow_sync = False
manager = peewee_async.Manager(database)
with manager.allow_sync():
for model in models:
for model in ALL_MODELS:
model._meta.database = database
model.create_table(True)

yield peewee_async.Manager(database)

with manager.allow_sync():
for model in reversed(sort_models(ALL_MODELS)):
model.delete().execute()
model._meta.database = None
await database.close_async()
for model in reversed(models):
model.drop_table(fail_silently=True)
model._meta.database = None


PG_DBS = [
Expand Down
1 change: 0 additions & 1 deletion tests/db_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
'mysql': MYSQL_DEFAULTS,
'mysql-pool': MYSQL_DEFAULTS
}
DB_OVERRIDES = {}
DB_CLASSES = {
'postgres': peewee_async.PostgresqlDatabase,
'postgres-ext': peewee_asyncext.PostgresqlExtDatabase,
Expand Down
9 changes: 9 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@ def __str__(self):


class TestModelAlpha(peewee.Model):
__test__ = False
text = peewee.CharField()

def __str__(self):
return '<%s id=%s> %s' % (self.__class__.__name__, self.id, self.text)


class TestModelBeta(peewee.Model):
__test__ = False
alpha = peewee.ForeignKeyField(TestModelAlpha, backref='betas')
text = peewee.CharField()

Expand All @@ -28,6 +30,7 @@ def __str__(self):


class TestModelGamma(peewee.Model):
__test__ = False
text = peewee.CharField()
beta = peewee.ForeignKeyField(TestModelBeta, backref='gammas')

Expand All @@ -50,3 +53,9 @@ class CompositeTestModel(peewee.Model):

class Meta:
primary_key = peewee.CompositeKey('uuid', 'alpha')


ALL_MODELS = (
TestModel, UUIDTestModel, TestModelAlpha,
TestModelBeta, TestModelGamma, CompositeTestModel
)
216 changes: 189 additions & 27 deletions tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,14 @@
import asyncio
import uuid

from tests.conftest import all_dbs
from tests.models import TestModel, UUIDTestModel, TestModelAlpha, CompositeTestModel

import peewee as pw
import pytest

@all_dbs
async def test_get_or_none(manager):
"""Test get_or_none manager function."""
text1 = "Test %s" % uuid.uuid4()
text2 = "Test %s" % uuid.uuid4()

obj1 = await manager.create(TestModel, text=text1)
obj2 = await manager.get_or_none(TestModel, text=text1)
obj3 = await manager.get_or_none(TestModel, text=text2)
import peewee_async
from tests.conftest import all_dbs
from tests.db_config import DB_CLASSES, DB_DEFAULTS
from tests.models import UUIDTestModel, TestModelAlpha, CompositeTestModel, TestModel

assert obj1 == obj2
assert obj1 is not None
assert obj2 is not None
assert obj3 is None

@all_dbs
async def test_composite_key(manager):
Expand All @@ -28,25 +19,196 @@ async def test_composite_key(manager):


@all_dbs
async def test_count_query_with_limit(manager):
async def test_multiple_iterate_over_result(manager):

obj1 = await manager.create(TestModel, text="Test 1")
obj2 = await manager.create(TestModel, text="Test 2")

result = await manager.execute(
TestModel.select().order_by(TestModel.text))

assert list(result) == [obj1, obj2]
assert list(result) == [obj1, obj2]


@all_dbs
async def test_indexing_result(manager):

await manager.create(TestModel, text="Test 1")
obj = await manager.create(TestModel, text="Test 2")

result = await manager.execute(
TestModel.select().order_by(TestModel.text)
)
assert obj == result[1]


@all_dbs
async def test_select_many_objects(manager):
text = "Test 1"
obj1 = await manager.create(TestModel, text=text)
text = "Test 2"
obj2 = await manager.create(TestModel, text=text)

select1 = [obj1, obj2]
len1 = len(select1)

select2 = await manager.execute(
TestModel.select().order_by(TestModel.text))
len2 = len([o for o in select2])

assert len1 == len2
for o1, o2 in zip(select1, select2):
assert o1 == o2


@all_dbs
async def test_raw_query(manager):

text = "Test %s" % uuid.uuid4()
await manager.create(TestModel, text=text)

result1 = await manager.execute(TestModel.raw(
'select id, text from testmodel'))
result1 = list(result1)
assert len(result1) == 1
assert isinstance(result1[0], TestModel) is True

result2 = await manager.execute(TestModel.raw(
'select id, text from testmodel').tuples())
result2 = list(result2)
assert len(result2) == 1
assert isinstance(result2[0], tuple) is True

result3 = await manager.execute(TestModel.raw(
'select id, text from testmodel').dicts())
result3 = list(result3)
assert len(result3) == 1
assert isinstance(result3[0], dict) is True


@all_dbs
async def test_get_obj_by_id(manager):
text = "Test %s" % uuid.uuid4()
await manager.create(TestModel, text=text)
obj1 = await manager.create(TestModel, text=text)
obj2 = await manager.get(TestModel, id=obj1.id)

assert obj1 == obj2
assert obj1.id == obj2.id


@all_dbs
async def test_get_obj_by_uuid(manager):

text = "Test %s" % uuid.uuid4()
await manager.create(TestModel, text=text)
obj1 = await manager.create(UUIDTestModel, text=text)
obj2 = await manager.get(UUIDTestModel, id=obj1.id)
assert obj1 == obj2
assert len(str(obj1.id)) == 36

count = await manager.count(TestModel.select().limit(1))
assert count == 1

@all_dbs
async def test_count_query(manager):
async def test_create_uuid_obj(manager):

text = "Test %s" % uuid.uuid4()
await manager.create(TestModel, text=text)
obj = await manager.create(UUIDTestModel, text=text)
assert len(str(obj.id)) == 36


@all_dbs
async def test_allow_sync_is_reverted_for_exc(manager):
try:
with manager.allow_sync():
ununique_text = "ununique_text"
await manager.create(TestModel, text=ununique_text)
await manager.create(TestModel, text=ununique_text)
except pw.IntegrityError:
pass
assert manager.database._allow_sync is False


@all_dbs
async def test_many_requests(manager):

max_connections = getattr(manager.database, 'max_connections', 1)
text = "Test %s" % uuid.uuid4()
await manager.create(TestModel, text=text)
obj = await manager.create(TestModel, text=text)
n = 2 * max_connections # number of requests
done, not_done = await asyncio.wait(
{asyncio.create_task(manager.get(TestModel, id=obj.id)) for _ in range(n)}
)
assert len(done) == n


@all_dbs
async def test_connect_close(manager):

async def get_conn(manager):
await manager.connect()
# await asyncio.sleep(0.05, loop=self.loop)
# NOTE: "private" member access
return manager.database._async_conn


c1 = await get_conn(manager)
c2 = await get_conn(manager)
assert c1 == c2

assert manager.is_connected is True

await manager.close()

assert manager.is_connected is False

done, not_done = await asyncio.wait({asyncio.create_task(get_conn(manager)) for _ in range(3)})

conn = next(iter(done)).result()
assert len(done) == 3
assert manager.is_connected is True
assert all(map(lambda t: t.result() == conn, done)) is True

await manager.close()
assert manager.is_connected is False


@pytest.mark.parametrize(
"params, db_cls",
[
(DB_DEFAULTS[name], db_cls) for name, db_cls in DB_CLASSES.items()
]
)
async def test_deferred_init(params, db_cls):

database = db_cls(None)
assert database.deferred is True

database.init(**params)
assert database.deferred is False

TestModel._meta.database = database
TestModel.create_table(True)
TestModel.drop_table(True)


@pytest.mark.parametrize(
"params, db_cls",
[
(DB_DEFAULTS[name], db_cls) for name, db_cls in DB_CLASSES.items()
]
)
async def test_proxy_database(params, db_cls):

database = pw.Proxy()
TestModel._meta.database = database
manager = peewee_async.Manager(database)

database.initialize(db_cls(**params))

TestModel.create_table(True)

text = "Test %s" % uuid.uuid4()
await manager.create(TestModel, text=text)

count = await manager.count(TestModel.select())
assert count == 3
await manager.get(TestModel, text=text)
TestModel.drop_table(True)
1 change: 0 additions & 1 deletion tests/test_inserting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ async def test_insert_many(manager):

last_id = await manager.execute(query)

TestModel.get()
res = await manager.execute(TestModel.select())
assert len(res) == 2
assert last_id in [m.id for m in res]
Expand Down
Loading

0 comments on commit dda6949

Please sign in to comment.