diff --git a/tests/conftest.py b/tests/conftest.py index a908eef..9ed4032 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,11 @@ import asyncio import pytest +from peewee import sort_models import peewee_async from tests.db_config import DB_CLASSES, DB_DEFAULTS -from tests.models import TestModel, UUIDTestModel, TestModelAlpha, TestModelBeta, TestModelGamma, CompositeTestModel +from tests.models import ALL_MODELS try: import aiopg @@ -34,19 +35,20 @@ async def manager(request): params = DB_DEFAULTS[db] database = DB_CLASSES[db](**params) - models = [TestModel, UUIDTestModel, TestModelAlpha, - TestModelBeta, TestModelGamma, CompositeTestModel] + database._allow_sync = False manager = peewee_async.Manager(database) with manager.allow_sync(): - for model in models: + for model in ALL_MODELS: model._meta.database = database model.create_table(True) yield peewee_async.Manager(database) + + with manager.allow_sync(): + for model in reversed(sort_models(ALL_MODELS)): + model.delete().execute() + model._meta.database = None await database.close_async() - for model in reversed(models): - model.drop_table(fail_silently=True) - model._meta.database = None PG_DBS = [ diff --git a/tests/db_config.py b/tests/db_config.py index 14def81..562157b 100644 --- a/tests/db_config.py +++ b/tests/db_config.py @@ -26,7 +26,6 @@ 'mysql': MYSQL_DEFAULTS, 'mysql-pool': MYSQL_DEFAULTS } -DB_OVERRIDES = {} DB_CLASSES = { 'postgres': peewee_async.PostgresqlDatabase, 'postgres-ext': peewee_asyncext.PostgresqlExtDatabase, diff --git a/tests/models.py b/tests/models.py index f769440..86c271c 100644 --- a/tests/models.py +++ b/tests/models.py @@ -13,6 +13,7 @@ def __str__(self): class TestModelAlpha(peewee.Model): + __test__ = False text = peewee.CharField() def __str__(self): @@ -20,6 +21,7 @@ def __str__(self): class TestModelBeta(peewee.Model): + __test__ = False alpha = peewee.ForeignKeyField(TestModelAlpha, backref='betas') text = peewee.CharField() @@ -28,6 +30,7 @@ def __str__(self): class TestModelGamma(peewee.Model): + __test__ = False text = peewee.CharField() beta = peewee.ForeignKeyField(TestModelBeta, backref='gammas') @@ -50,3 +53,9 @@ class CompositeTestModel(peewee.Model): class Meta: primary_key = peewee.CompositeKey('uuid', 'alpha') + + +ALL_MODELS = ( + TestModel, UUIDTestModel, TestModelAlpha, + TestModelBeta, TestModelGamma, CompositeTestModel +) diff --git a/tests/test_common.py b/tests/test_common.py index 9794f41..bb4d8fe 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -1,23 +1,14 @@ +import asyncio import uuid -from tests.conftest import all_dbs -from tests.models import TestModel, UUIDTestModel, TestModelAlpha, CompositeTestModel - +import peewee as pw +import pytest -@all_dbs -async def test_get_or_none(manager): - """Test get_or_none manager function.""" - text1 = "Test %s" % uuid.uuid4() - text2 = "Test %s" % uuid.uuid4() - - obj1 = await manager.create(TestModel, text=text1) - obj2 = await manager.get_or_none(TestModel, text=text1) - obj3 = await manager.get_or_none(TestModel, text=text2) +import peewee_async +from tests.conftest import all_dbs +from tests.db_config import DB_CLASSES, DB_DEFAULTS +from tests.models import UUIDTestModel, TestModelAlpha, CompositeTestModel, TestModel - assert obj1 == obj2 - assert obj1 is not None - assert obj2 is not None - assert obj3 is None @all_dbs async def test_composite_key(manager): @@ -28,25 +19,196 @@ async def test_composite_key(manager): @all_dbs -async def test_count_query_with_limit(manager): +async def test_multiple_iterate_over_result(manager): + + obj1 = await manager.create(TestModel, text="Test 1") + obj2 = await manager.create(TestModel, text="Test 2") + + result = await manager.execute( + TestModel.select().order_by(TestModel.text)) + + assert list(result) == [obj1, obj2] + assert list(result) == [obj1, obj2] + + +@all_dbs +async def test_indexing_result(manager): + + await manager.create(TestModel, text="Test 1") + obj = await manager.create(TestModel, text="Test 2") + + result = await manager.execute( + TestModel.select().order_by(TestModel.text) + ) + assert obj == result[1] + + +@all_dbs +async def test_select_many_objects(manager): + text = "Test 1" + obj1 = await manager.create(TestModel, text=text) + text = "Test 2" + obj2 = await manager.create(TestModel, text=text) + + select1 = [obj1, obj2] + len1 = len(select1) + + select2 = await manager.execute( + TestModel.select().order_by(TestModel.text)) + len2 = len([o for o in select2]) + + assert len1 == len2 + for o1, o2 in zip(select1, select2): + assert o1 == o2 + + +@all_dbs +async def test_raw_query(manager): + text = "Test %s" % uuid.uuid4() await manager.create(TestModel, text=text) + + result1 = await manager.execute(TestModel.raw( + 'select id, text from testmodel')) + result1 = list(result1) + assert len(result1) == 1 + assert isinstance(result1[0], TestModel) is True + + result2 = await manager.execute(TestModel.raw( + 'select id, text from testmodel').tuples()) + result2 = list(result2) + assert len(result2) == 1 + assert isinstance(result2[0], tuple) is True + + result3 = await manager.execute(TestModel.raw( + 'select id, text from testmodel').dicts()) + result3 = list(result3) + assert len(result3) == 1 + assert isinstance(result3[0], dict) is True + + +@all_dbs +async def test_get_obj_by_id(manager): text = "Test %s" % uuid.uuid4() - await manager.create(TestModel, text=text) + obj1 = await manager.create(TestModel, text=text) + obj2 = await manager.get(TestModel, id=obj1.id) + + assert obj1 == obj2 + assert obj1.id == obj2.id + + +@all_dbs +async def test_get_obj_by_uuid(manager): + text = "Test %s" % uuid.uuid4() - await manager.create(TestModel, text=text) + obj1 = await manager.create(UUIDTestModel, text=text) + obj2 = await manager.get(UUIDTestModel, id=obj1.id) + assert obj1 == obj2 + assert len(str(obj1.id)) == 36 - count = await manager.count(TestModel.select().limit(1)) - assert count == 1 @all_dbs -async def test_count_query(manager): +async def test_create_uuid_obj(manager): + text = "Test %s" % uuid.uuid4() - await manager.create(TestModel, text=text) + obj = await manager.create(UUIDTestModel, text=text) + assert len(str(obj.id)) == 36 + + +@all_dbs +async def test_allow_sync_is_reverted_for_exc(manager): + try: + with manager.allow_sync(): + ununique_text = "ununique_text" + await manager.create(TestModel, text=ununique_text) + await manager.create(TestModel, text=ununique_text) + except pw.IntegrityError: + pass + assert manager.database._allow_sync is False + + +@all_dbs +async def test_many_requests(manager): + + max_connections = getattr(manager.database, 'max_connections', 1) text = "Test %s" % uuid.uuid4() - await manager.create(TestModel, text=text) + obj = await manager.create(TestModel, text=text) + n = 2 * max_connections # number of requests + done, not_done = await asyncio.wait( + {asyncio.create_task(manager.get(TestModel, id=obj.id)) for _ in range(n)} + ) + assert len(done) == n + + +@all_dbs +async def test_connect_close(manager): + + async def get_conn(manager): + await manager.connect() + # await asyncio.sleep(0.05, loop=self.loop) + # NOTE: "private" member access + return manager.database._async_conn + + + c1 = await get_conn(manager) + c2 = await get_conn(manager) + assert c1 == c2 + + assert manager.is_connected is True + + await manager.close() + + assert manager.is_connected is False + + done, not_done = await asyncio.wait({asyncio.create_task(get_conn(manager)) for _ in range(3)}) + + conn = next(iter(done)).result() + assert len(done) == 3 + assert manager.is_connected is True + assert all(map(lambda t: t.result() == conn, done)) is True + + await manager.close() + assert manager.is_connected is False + + +@pytest.mark.parametrize( + "params, db_cls", + [ + (DB_DEFAULTS[name], db_cls) for name, db_cls in DB_CLASSES.items() + ] + +) +async def test_deferred_init(params, db_cls): + + database = db_cls(None) + assert database.deferred is True + + database.init(**params) + assert database.deferred is False + + TestModel._meta.database = database + TestModel.create_table(True) + TestModel.drop_table(True) + + +@pytest.mark.parametrize( + "params, db_cls", + [ + (DB_DEFAULTS[name], db_cls) for name, db_cls in DB_CLASSES.items() + ] + +) +async def test_proxy_database(params, db_cls): + + database = pw.Proxy() + TestModel._meta.database = database + manager = peewee_async.Manager(database) + + database.initialize(db_cls(**params)) + + TestModel.create_table(True) + text = "Test %s" % uuid.uuid4() await manager.create(TestModel, text=text) - - count = await manager.count(TestModel.select()) - assert count == 3 + await manager.get(TestModel, text=text) + TestModel.drop_table(True) diff --git a/tests/test_inserting.py b/tests/test_inserting.py index 60a0e68..725832f 100644 --- a/tests/test_inserting.py +++ b/tests/test_inserting.py @@ -14,7 +14,6 @@ async def test_insert_many(manager): last_id = await manager.execute(query) - TestModel.get() res = await manager.execute(TestModel.select()) assert len(res) == 2 assert last_id in [m.id for m in res] diff --git a/tests/test_obsolete.py b/tests/test_obsolete.py deleted file mode 100644 index 9d4f281..0000000 --- a/tests/test_obsolete.py +++ /dev/null @@ -1,435 +0,0 @@ -""" -peewee-async tests -================== - -Create tests.ini file to configure tests. - -""" -import asyncio -import contextlib -import unittest -import uuid - -import peewee - -import peewee_async -from tests.db_config import DB_DEFAULTS, DB_OVERRIDES, DB_CLASSES -from tests.models import TestModel, TestModelAlpha, TestModelBeta, TestModelGamma, UUIDTestModel, CompositeTestModel - -try: - import aiopg -except ImportError: - aiopg = None - -try: - import aiomysql -except ImportError: - aiomysql = None - - -def setUpModule(): - if not aiopg: - print("aiopg is not installed, ignoring PostgreSQL tests") - for key in list(DB_CLASSES.keys()): - if key.startswith('postgres'): - DB_CLASSES.pop(key) - - if not aiomysql: - print("aiomysql is not installed, ignoring MySQL tests") - for key in list(DB_CLASSES.keys()): - if key.startswith('mysql'): - DB_CLASSES.pop(key) - - loop = asyncio.new_event_loop() - all_databases = load_databases(only=None) - for key, database in all_databases.items(): - connect = database.connect_async(loop=loop) - loop.run_until_complete(connect) - if database._async_conn is not None: - disconnect = database.close_async() - loop.run_until_complete(disconnect) - else: - print("Can't setup connection for %s" % key) - DB_CLASSES.pop(key) - - -def load_managers(*, loop, only): - managers = {} - for key in DB_CLASSES: - if only and key not in only: - continue - params = DB_DEFAULTS.get(key) or {} - params.update(DB_OVERRIDES.get(key) or {}) - database = DB_CLASSES[key](**params) - managers[key] = peewee_async.Manager(database, loop=loop) - return managers - - -def load_databases(*, only): - databases = {} - for key in DB_CLASSES: - if only and key not in only: - continue - params = DB_DEFAULTS.get(key) or {} - params.update(DB_OVERRIDES.get(key) or {}) - databases[key] = DB_CLASSES[key](**params) - return databases - -#################### -# Base tests class # -#################### - - -class BaseManagerTestCase(unittest.TestCase): - only = None - - models = [TestModel, UUIDTestModel, TestModelAlpha, - TestModelBeta, TestModelGamma, CompositeTestModel] - - @classmethod - @contextlib.contextmanager - def manager(cls, objects, allow_sync=False): - for model in cls.models: - model._meta.database = objects.database - if allow_sync: - with objects.allow_sync(): - yield - else: - yield - - def setUp(self): - """Setup the new event loop, and database configs, reset counter. - """ - self.run_count = 0 - self.loop = asyncio.new_event_loop() - self.managers = load_managers(loop=self.loop, only=self.only) - - # Clean up before tests - for _, objects in self.managers.items(): - objects.database.set_allow_sync(False) - with self.manager(objects, allow_sync=True): - for model in self.models: - model.create_table(True) - for model in reversed(self.models): - model.delete().execute() - - def tearDown(self): - """Check if test was actually passed by counter, clean up. - """ - self.assertEqual(len(self.managers), self.run_count) - - for _, objects in self.managers.items(): - self.loop.run_until_complete(objects.close()) - self.loop.close() - - for _, objects in self.managers.items(): - with self.manager(objects, allow_sync=True): - for model in reversed(self.models): - model.drop_table(fail_silently=True) - - self.managers = None - - def run_with_managers(self, test, exclude=None): - """Run test coroutine against available Manager instances. - - test -- coroutine with single parameter, Manager instance - exclude -- exclude list or string for manager key - - Example: - - async def test(objects): - # ... - - run_with_managers(test, exclude=['mysql', 'mysql-pool']) - """ - for key, objects in self.managers.items(): - if exclude is None or (key not in exclude): - with self.manager(objects, allow_sync=False): - self.loop.run_until_complete(test(objects)) - with self.manager(objects, allow_sync=True): - for model in reversed(self.models): - model.delete().execute() - self.run_count += 1 - - -################ -# Common tests # -################ - - -class DatabaseTestCase(unittest.TestCase): - def test_deferred_init(self): - for key in DB_CLASSES: - params = DB_DEFAULTS.get(key) or {} - params.update(DB_OVERRIDES.get(key) or {}) - - database = DB_CLASSES[key](None) - self.assertTrue(database.deferred) - - database.init(**params) - self.assertTrue(not database.deferred) - - TestModel._meta.database = database - TestModel.create_table(True) - TestModel.drop_table(True) - - def test_proxy_database(self): - loop = asyncio.new_event_loop() - database = peewee.Proxy() - TestModel._meta.database = database - objects = peewee_async.Manager(database, loop=loop) - - async def test(objects): - text = "Test %s" % uuid.uuid4() - await objects.create(TestModel, text=text) - await objects.get(TestModel, text=text) - - for key in DB_CLASSES: - params = DB_DEFAULTS.get(key) or {} - params.update(DB_OVERRIDES.get(key) or {}) - database.initialize(DB_CLASSES[key](**params)) - - TestModel.create_table(True) - loop.run_until_complete(test(objects)) - loop.run_until_complete(objects.close()) - TestModel.drop_table(True) - - loop.close() - - -class ManagerTestCase(BaseManagerTestCase): - # only = ['postgres', 'postgres-ext', 'postgres-pool', 'postgres-pool-ext'] - only = None - - def test_connect_close(self): - async def get_conn(objects): - await objects.connect() - # await asyncio.sleep(0.05, loop=self.loop) - # NOTE: "private" member access - return objects.database._async_conn - - async def test(objects): - c1 = await get_conn(objects) - c2 = await get_conn(objects) - self.assertEqual(c1, c2) - self.assertTrue(objects.is_connected) - - await objects.close() - self.assertTrue(not objects.is_connected) - - done, not_done = await asyncio.wait({self.loop.create_task(get_conn(objects)) for _ in range(3)}) - - conn = next(iter(done)).result() - self.assertEqual(len(done), 3) - self.assertTrue(objects.is_connected) - self.assertTrue(all(map(lambda t: t.result() == conn, done))) - - await objects.close() - self.assertTrue(not objects.is_connected) - - self.run_with_managers(test) - - def test_many_requests(self): - async def test(objects): - max_connections = getattr(objects.database, 'max_connections', 1) - text = "Test %s" % uuid.uuid4() - obj = await objects.create(TestModel, text=text) - n = 2 * max_connections # number of requests - done, not_done = await asyncio.wait( - {self.loop.create_task(objects.get(TestModel, id=obj.id)) for _ in range(n)} - ) - self.assertEqual(len(done), n) - - self.run_with_managers(test) - - def test_allow_sync_is_reverted_for_exc(self): - async def test(objects): - try: - with objects.allow_sync(): - ununique_text = "ununique_text" - await objects.create(TestModel, text=ununique_text) - await objects.create(TestModel, text=ununique_text) - except peewee.IntegrityError: - pass - self.assertFalse(objects.database._allow_sync) - - self.run_with_managers(test) - - def test_create_obj(self): - async def test(objects): - text = "Test %s" % uuid.uuid4() - obj = await objects.create(TestModel, text=text) - self.assertTrue(obj is not None) - self.assertEqual(obj.text, text) - - self.run_with_managers(test) - - def test_create_or_get(self): - async def test(objects): - text = "Test %s" % uuid.uuid4() - obj1, created1 = await objects.create_or_get( - TestModel, text=text, data="Data 1") - obj2, created2 = await objects.create_or_get( - TestModel, text=text, data="Data 2") - - self.assertTrue(created1) - self.assertTrue(not created2) - self.assertEqual(obj1, obj2) - self.assertEqual(obj1.data, "Data 1") - self.assertEqual(obj2.data, "Data 1") - - self.run_with_managers(test) - - def test_get_or_create(self): - async def test(objects): - text = "Test %s" % uuid.uuid4() - - obj1, created1 = await objects.get_or_create( - TestModel, text=text, defaults={'data': "Data 1"}) - obj2, created2 = await objects.get_or_create( - TestModel, text=text, defaults={'data': "Data 2"}) - - self.assertTrue(created1) - self.assertTrue(not created2) - self.assertEqual(obj1, obj2) - self.assertEqual(obj1.data, "Data 1") - self.assertEqual(obj2.data, "Data 1") - - self.run_with_managers(test) - - def test_create_uuid_obj(self): - async def test(objects): - text = "Test %s" % uuid.uuid4() - obj = await objects.create(UUIDTestModel, text=text) - self.assertEqual(len(str(obj.id)), 36) - - self.run_with_managers(test, exclude=['mysql', 'mysql-pool']) - - def test_get_obj_by_id(self): - async def test(objects): - text = "Test %s" % uuid.uuid4() - obj1 = await objects.create(TestModel, text=text) - obj2 = await objects.get(TestModel, id=obj1.id) - self.assertEqual(obj1, obj2) - self.assertEqual(obj1.id, obj2.id) - - self.run_with_managers(test) - - def test_get_obj_by_uuid(self): - async def test(objects): - text = "Test %s" % uuid.uuid4() - obj1 = await objects.create(UUIDTestModel, text=text) - obj2 = await objects.get(UUIDTestModel, id=obj1.id) - self.assertEqual(obj1, obj2) - self.assertEqual(len(str(obj1.id)), 36) - - self.run_with_managers(test) - - def test_raw_query(self): - async def test(objects): - text = "Test %s" % uuid.uuid4() - await objects.create(TestModel, text=text) - - result1 = await objects.execute(TestModel.raw( - 'select id, text from testmodel')) - result1 = list(result1) - self.assertEqual(len(result1), 1) - self.assertTrue(isinstance(result1[0], TestModel)) - - result2 = await objects.execute(TestModel.raw( - 'select id, text from testmodel').tuples()) - result2 = list(result2) - self.assertEqual(len(result2), 1) - self.assertTrue(isinstance(result2[0], tuple)) - - result3 = await objects.execute(TestModel.raw( - 'select id, text from testmodel').dicts()) - result3 = list(result3) - self.assertEqual(len(result3), 1) - self.assertTrue(isinstance(result3[0], dict)) - - self.run_with_managers(test) - - def test_select_many_objects(self): - async def test(objects): - text = "Test 1" - obj1 = await objects.create(TestModel, text=text) - text = "Test 2" - obj2 = await objects.create(TestModel, text=text) - - select1 = [obj1, obj2] - len1 = len(select1) - - select2 = await objects.execute( - TestModel.select().order_by(TestModel.text)) - len2 = len([o for o in select2]) - - self.assertEqual(len1, len2) - for o1, o2 in zip(select1, select2): - self.assertEqual(o1, o2) - - self.run_with_managers(test) - - def test_indexing_result(self): - async def test(objects): - await objects.create(TestModel, text="Test 1") - obj = await objects.create(TestModel, text="Test 2") - result = await objects.execute( - TestModel.select().order_by(TestModel.text)) - self.assertEqual(obj, result[1]) - - self.run_with_managers(test) - - def test_multiple_iterate_over_result(self): - async def test(objects): - obj1 = await objects.create(TestModel, text="Test 1") - obj2 = await objects.create(TestModel, text="Test 2") - result = await objects.execute( - TestModel.select().order_by(TestModel.text)) - self.assertEqual(list(result), [obj1, obj2]) - self.assertEqual(list(result), [obj1, obj2]) - - self.run_with_managers(test) - - def test_update_obj(self): - async def test(objects): - text = "Test %s" % uuid.uuid4() - obj1 = await objects.create(TestModel, text=text) - - obj1.text = "Test update object" - await objects.update(obj1) - - obj2 = await objects.get(TestModel, id=obj1.id) - self.assertEqual(obj2.text, "Test update object") - - self.run_with_managers(test) - - def test_delete_obj(self): - async def test(objects): - text = "Test %s" % uuid.uuid4() - obj1 = await objects.create(TestModel, text=text) - - obj2 = await objects.get(TestModel, id=obj1.id) - - await objects.delete(obj2) - try: - obj3 = await objects.get(TestModel, id=obj1.id) - except TestModel.DoesNotExist: - obj3 = None - self.assertTrue(obj3 is None, "Error, object wasn't deleted") - - self.run_with_managers(test) - - def test_scalar_query(self): - async def test(objects): - text = "Test %s" % uuid.uuid4() - await objects.create(TestModel, text=text) - text = "Test %s" % uuid.uuid4() - await objects.create(TestModel, text=text) - - fn = peewee.fn.Count(TestModel.id) - count = await objects.scalar(TestModel.select(fn)) - self.assertEqual(count, 2) - - self.run_with_managers(test) - diff --git a/tests/test_prefetch.py b/tests/test_prefetch.py deleted file mode 100644 index 821b717..0000000 --- a/tests/test_prefetch.py +++ /dev/null @@ -1,41 +0,0 @@ -import peewee -import pytest - -from tests.conftest import all_dbs -from tests.models import TestModelAlpha, TestModelBeta, TestModelGamma - - -@all_dbs -@pytest.mark.parametrize( - "prefetch_type", - peewee.PREFETCH_TYPE.values() -) -async def test_prefetch(manager, prefetch_type): - alpha_1 = await manager.create( - TestModelAlpha, text='Alpha 1') - alpha_2 = await manager.create( - TestModelAlpha, text='Alpha 2') - - beta_11 = await manager.create( - TestModelBeta, alpha=alpha_1, text='Beta 11') - beta_12 = await manager.create( - TestModelBeta, alpha=alpha_1, text='Beta 12') - _ = await manager.create( - TestModelBeta, alpha=alpha_2, text='Beta 21') - _ = await manager.create( - TestModelBeta, alpha=alpha_2, text='Beta 22') - - gamma_111 = await manager.create( - TestModelGamma, beta=beta_11, text='Gamma 111') - gamma_112 = await manager.create( - TestModelGamma, beta=beta_11, text='Gamma 112') - - result = await manager.prefetch( - TestModelAlpha.select().order_by(TestModelAlpha.id), - TestModelBeta.select().order_by(TestModelBeta.id), - TestModelGamma.select().order_by(TestModelGamma.id), - prefetch_type=prefetch_type, - ) - assert tuple(result) == (alpha_1, alpha_2) - assert tuple(result[0].betas) == (beta_11, beta_12) - assert tuple(result[0].betas[0].gammas) == (gamma_111, gamma_112) diff --git a/tests/test_shortcuts.py b/tests/test_shortcuts.py new file mode 100644 index 0000000..ad639a5 --- /dev/null +++ b/tests/test_shortcuts.py @@ -0,0 +1,168 @@ +import uuid + +import peewee +import peewee as pw +import pytest + +from tests.conftest import all_dbs +from tests.models import TestModel, TestModelAlpha, TestModelBeta, TestModelGamma + + +@all_dbs +@pytest.mark.parametrize( + "prefetch_type", + peewee.PREFETCH_TYPE.values() +) +async def test_prefetch(manager, prefetch_type): + alpha_1 = await manager.create( + TestModelAlpha, text='Alpha 1') + alpha_2 = await manager.create( + TestModelAlpha, text='Alpha 2') + + beta_11 = await manager.create( + TestModelBeta, alpha=alpha_1, text='Beta 11') + beta_12 = await manager.create( + TestModelBeta, alpha=alpha_1, text='Beta 12') + _ = await manager.create( + TestModelBeta, alpha=alpha_2, text='Beta 21') + _ = await manager.create( + TestModelBeta, alpha=alpha_2, text='Beta 22') + + gamma_111 = await manager.create( + TestModelGamma, beta=beta_11, text='Gamma 111') + gamma_112 = await manager.create( + TestModelGamma, beta=beta_11, text='Gamma 112') + + result = await manager.prefetch( + TestModelAlpha.select().order_by(TestModelAlpha.id), + TestModelBeta.select().order_by(TestModelBeta.id), + TestModelGamma.select().order_by(TestModelGamma.id), + prefetch_type=prefetch_type, + ) + assert tuple(result) == (alpha_1, alpha_2) + assert tuple(result[0].betas) == (beta_11, beta_12) + assert tuple(result[0].betas[0].gammas) == (gamma_111, gamma_112) + + +@all_dbs +async def test_get_or_none(manager): + """Test get_or_none manager function.""" + text1 = "Test %s" % uuid.uuid4() + text2 = "Test %s" % uuid.uuid4() + + obj1 = await manager.create(TestModel, text=text1) + obj2 = await manager.get_or_none(TestModel, text=text1) + obj3 = await manager.get_or_none(TestModel, text=text2) + + assert obj1 == obj2 + assert obj1 is not None + assert obj2 is not None + assert obj3 is None + + +@all_dbs +async def test_count_query_with_limit(manager): + text = "Test %s" % uuid.uuid4() + await manager.create(TestModel, text=text) + text = "Test %s" % uuid.uuid4() + await manager.create(TestModel, text=text) + text = "Test %s" % uuid.uuid4() + await manager.create(TestModel, text=text) + + count = await manager.count(TestModel.select().limit(1)) + assert count == 1 + + +@all_dbs +async def test_count_query(manager): + text = "Test %s" % uuid.uuid4() + await manager.create(TestModel, text=text) + text = "Test %s" % uuid.uuid4() + await manager.create(TestModel, text=text) + text = "Test %s" % uuid.uuid4() + await manager.create(TestModel, text=text) + + count = await manager.count(TestModel.select()) + assert count == 3 + + +@all_dbs +async def test_scalar_query(manager): + + text = "Test %s" % uuid.uuid4() + await manager.create(TestModel, text=text) + text = "Test %s" % uuid.uuid4() + await manager.create(TestModel, text=text) + + fn = pw.fn.Count(TestModel.id) + count = await manager.scalar(TestModel.select(fn)) + + assert count == 2 + + +@all_dbs +async def test_delete_obj(manager): + text = "Test %s" % uuid.uuid4() + obj1 = await manager.create(TestModel, text=text) + obj2 = await manager.get(TestModel, id=obj1.id) + + await manager.delete(obj2) + + obj3 = await manager.get_or_none(TestModel, id=obj1.id) + assert obj3 is None + + +@all_dbs +async def test_update_obj(manager): + + text = "Test %s" % uuid.uuid4() + obj1 = await manager.create(TestModel, text=text) + + obj1.text = "Test update object" + await manager.update(obj1) + + obj2 = await manager.get(TestModel, id=obj1.id) + assert obj2.text == "Test update object" + + +@all_dbs +async def test_create_obj(manager): + + text = "Test %s" % uuid.uuid4() + obj = await manager.create(TestModel, text=text) + assert obj is not None + assert obj.text == text + + +@all_dbs +async def test_create_or_get(manager): + text = "Test %s" % uuid.uuid4() + obj1, created1 = await manager.create_or_get( + TestModel, text=text, data="Data 1") + obj2, created2 = await manager.create_or_get( + TestModel, text=text, data="Data 2") + + assert created1 is True + assert created2 is False + assert obj1 == obj2 + assert obj1.data == "Data 1" + assert obj2.data == "Data 1" + + +@all_dbs +async def test_get_or_create(manager): + + text = "Test %s" % uuid.uuid4() + + obj1, created1 = await manager.get_or_create( + TestModel, text=text, defaults={'data': "Data 1"}) + obj2, created2 = await manager.get_or_create( + TestModel, text=text, defaults={'data': "Data 2"}) + + assert created1 is True + assert created2 is False + assert obj1 == obj2 + assert obj1.data == "Data 1" + assert obj2.data == "Data 1" + +