Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/v0.4.0' into feat(async)
Browse files Browse the repository at this point in the history
# Conflicts:
#	graphene_mongo/converter.py
#	graphene_mongo/fields.py
#	pyproject.toml
  • Loading branch information
arunsureshkumar committed Apr 11, 2023
2 parents 5ff73e6 + 9581ad2 commit 6eb5d23
Show file tree
Hide file tree
Showing 12 changed files with 2,016 additions and 104 deletions.
8 changes: 6 additions & 2 deletions graphene_mongo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from .fields import MongoengineConnectionField
from .fields_async import AsyncMongoengineConnectionField

from .types import MongoengineObjectType, MongoengineInputType, MongoengineInterfaceType
from .types_async import AsyncMongoengineObjectType

__version__ = "0.1.1"

__all__ = [
"__version__",
"MongoengineObjectType",
"AsyncMongoengineObjectType",
"MongoengineInputType",
"MongoengineInterfaceType",
"MongoengineConnectionField"
]
"MongoengineConnectionField",
"AsyncMongoengineConnectionField"
]
296 changes: 242 additions & 54 deletions graphene_mongo/converter.py

Large diffs are not rendered by default.

63 changes: 33 additions & 30 deletions graphene_mongo/fields.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import absolute_import

import logging
from collections import OrderedDict
from functools import partial, reduce

Expand All @@ -21,8 +22,6 @@
from mongoengine.base import get_document
from promise import Promise
from pymongo.errors import OperationFailure
from asgiref.sync import sync_to_async
from concurrent.futures import ThreadPoolExecutor

from .advanced_types import (
FileFieldType,
Expand All @@ -33,7 +32,7 @@
from .converter import convert_mongoengine_field, MongoEngineConversionError
from .registry import get_global_registry
from .utils import get_model_reference_fields, get_query_fields, find_skip_and_limit, \
connection_from_iterables
connection_from_iterables, ExecutorEnum
import pymongo

PYMONGO_VERSION = tuple(pymongo.version_tuple[:2])
Expand All @@ -49,6 +48,10 @@ def __init__(self, type, *args, **kwargs):
self._get_queryset = get_queryset
super(MongoengineConnectionField, self).__init__(type, *args, **kwargs)

@property
def executor(self):
return ExecutorEnum.SYNC

@property
def type(self):
from .types import MongoengineObjectType
Expand Down Expand Up @@ -137,7 +140,7 @@ def is_filterable(k):
return False
try:
converted = convert_mongoengine_field(
getattr(self.model, k), self.registry
getattr(self.model, k), self.registry, self.executor
)
except MongoEngineConversionError:
return False
Expand Down Expand Up @@ -315,7 +318,7 @@ def get_queryset(self, model, info, required_fields=None, skip=None, limit=None,
skip)
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by)

async def default_resolver(self, _root, info, required_fields=None, resolved=None, **args):
def default_resolver(self, _root, info, required_fields=None, resolved=None, **args):
if required_fields is None:
required_fields = list()
args = args or {}
Expand Down Expand Up @@ -358,8 +361,7 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non

if isinstance(items, QuerySet):
try:
count = await sync_to_async(items.count, thread_sensitive=False,
executor=ThreadPoolExecutor())(with_limit_and_skip=True)
count = items.count(with_limit_and_skip=True)
except OperationFailure:
count = len(items)
else:
Expand Down Expand Up @@ -402,13 +404,13 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
args_copy[key] = args_copy[key].value

if PYMONGO_VERSION >= (3, 7):
count = await sync_to_async(
(mongoengine.get_db()[self.model._get_collection_name()]).count_documents,
thread_sensitive=False,
executor=ThreadPoolExecutor())(args_copy)
if hasattr(self.model, '_meta') and 'db_alias' in self.model._meta:
count = (mongoengine.get_db(self.model._meta['db_alias'])[
self.model._get_collection_name()]).count_documents(args_copy)
else:
count = (mongoengine.get_db()[self.model._get_collection_name()]).count_documents(args_copy)
else:
count = await sync_to_async(self.model.objects(args_copy).count, thread_sensitive=False,
executor=ThreadPoolExecutor())()
count = self.model.objects(args_copy).count()
if count != 0:
skip, limit, reverse = find_skip_and_limit(first=first, after=after, last=last, before=before,
count=count)
Expand Down Expand Up @@ -470,7 +472,7 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non
connection.list_length = list_length
return connection

async def chained_resolver(self, resolver, is_partial, root, info, **args):
def chained_resolver(self, resolver, is_partial, root, info, **args):

for key, value in dict(args).items():
if value is None:
Expand Down Expand Up @@ -514,13 +516,13 @@ async def chained_resolver(self, resolver, is_partial, root, info, **args):
elif not isinstance(resolved[0], DBRef):
return resolved
else:
return await self.default_resolver(root, info, required_fields, **args_copy)
return self.default_resolver(root, info, required_fields, **args_copy)
elif isinstance(resolved, QuerySet):
args.update(resolved._query)
args_copy = args.copy()
for arg_name, arg in args.copy().items():
if "." in arg_name or arg_name not in self.model._fields_ordered + (
'first', 'last', 'before', 'after') + tuple(self.filter_args.keys()):
if "." in arg_name or arg_name not in self.model._fields_ordered \
+ ('first', 'last', 'before', 'after') + tuple(self.filter_args.keys()):
args_copy.pop(arg_name)
if arg_name == '_id' and isinstance(arg, dict):
operation = list(arg.keys())[0]
Expand All @@ -540,37 +542,38 @@ async def chained_resolver(self, resolver, is_partial, root, info, **args):
operation = list(arg.keys())[0]
args_copy[arg_name + operation.replace('$', '__')] = arg[operation]
del args_copy[arg_name]

return await self.default_resolver(root, info, required_fields, resolved=resolved, **args_copy)
return self.default_resolver(root, info, required_fields, resolved=resolved, **args_copy)
elif isinstance(resolved, Promise):
return resolved.value
else:
return await resolved
return resolved

return await self.default_resolver(root, info, required_fields, **args)
return self.default_resolver(root, info, required_fields, **args)

@classmethod
async def connection_resolver(cls, resolver, connection_type, root, info, **args):
def connection_resolver(cls, resolver, connection_type, root, info, **args):
if root:
for key, value in root.__dict__.items():
if value:
try:
setattr(root, key, from_global_id(value)[1])
except Exception:
pass
iterable = await resolver(root, info, **args)
except Exception as error:
logging.error("Exception Occurred: ", exc_info=error)
iterable = resolver(root, info, **args)

if isinstance(connection_type, graphene.NonNull):
connection_type = connection_type.of_type

on_resolve = partial(cls.resolve_connection, connection_type, args)

if Promise.is_thenable(iterable):
on_resolve = partial(cls.resolve_connection, connection_type, args)
iterable = Promise.resolve(iterable).then(on_resolve).value
return await sync_to_async(cls.resolve_connection, thread_sensitive=False,
executor=ThreadPoolExecutor())(connection_type, args, iterable)
return Promise.resolve(iterable).then(on_resolve)

return on_resolve(iterable)

def get_resolver(self, parent_resolver):
super_resolver = self.resolver or parent_resolver
resolver = partial(
self.chained_resolver, super_resolver, isinstance(super_resolver, partial)
)

return partial(self.connection_resolver, resolver, self.type)
Loading

0 comments on commit 6eb5d23

Please sign in to comment.