Skip to content

Commit

Permalink
Merge branch 'graphql-python:master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
arunsureshkumar authored Apr 11, 2023
2 parents ec1c7af + b56fb6c commit 1f0250b
Show file tree
Hide file tree
Showing 13 changed files with 185 additions and 770 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ query = '''
}
}
'''
result = schema.execute(query)
result = await schema.execute_async(query)
```

To learn more check out the following [examples](examples/):
Expand Down
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ Then you can simply query the schema:
}
}
'''
result = schema.execute(query)
result = await schema.execute_async(query)
To learn more check out the `Flask MongoEngine example <https://github.com/graphql-python/graphene-mongo/tree/master/examples/flask_mongoengine>`__

4 changes: 2 additions & 2 deletions examples/falcon_mongoengine/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def on_post(self, req, resp):
class GraphQLResource:
def on_get(self, req, resp):
query = req.params["query"]
result = schema.execute(query)
result = await schema.execute_async(query)

if result.data:
data_ret = {"data": result.data}
Expand All @@ -32,7 +32,7 @@ def on_get(self, req, resp):

def on_post(self, req, resp):
query = req.params["query"]
result = schema.execute(query)
result = await schema.execute_async(query)
if result.data:
data_ret = {"data": result.data}
resp.status = falcon.HTTP_200
Expand Down
62 changes: 33 additions & 29 deletions graphene_mongo/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from . import advanced_types
from .utils import import_single_dispatch, get_field_description, get_query_fields
from concurrent.futures import ThreadPoolExecutor, as_completed
from asgiref.sync import sync_to_async

singledispatch = import_single_dispatch()

Expand Down Expand Up @@ -42,6 +43,14 @@ def convert_field_to_id(field, registry=None):
)


@convert_mongoengine_field.register(mongoengine.Decimal128Field)
@convert_mongoengine_field.register(mongoengine.DecimalField)
def convert_field_to_decimal(field, registry=None):
return graphene.Decimal(
description=get_field_description(field, registry), required=field.required
)


@convert_mongoengine_field.register(mongoengine.IntField)
@convert_mongoengine_field.register(mongoengine.LongField)
@convert_mongoengine_field.register(mongoengine.SequenceField)
Expand All @@ -58,21 +67,13 @@ def convert_field_to_boolean(field, registry=None):
)


@convert_mongoengine_field.register(mongoengine.DecimalField)
@convert_mongoengine_field.register(mongoengine.FloatField)
def convert_field_to_float(field, registry=None):
return graphene.Float(
description=get_field_description(field, registry), required=field.required
)


@convert_mongoengine_field.register(mongoengine.Decimal128Field)
def convert_field_to_decimal(field, registry=None):
return graphene.Decimal(
description=get_field_description(field, registry), required=field.required
)


@convert_mongoengine_field.register(mongoengine.DateTimeField)
def convert_field_to_datetime(field, registry=None):
return graphene.DateTime(
Expand Down Expand Up @@ -246,7 +247,7 @@ def convert_field_to_union(field, registry=None):
Meta = type("Meta", (object,), {"types": tuple(_types)})
_union = type(name, (graphene.Union,), {"Meta": Meta})

def reference_resolver(root, *args, **kwargs):
async def reference_resolver(root, *args, **kwargs):
de_referenced = getattr(root, field.name or field.db_name)
if de_referenced:
document = get_document(de_referenced["_cls"])
Expand All @@ -265,13 +266,14 @@ def reference_resolver(root, *args, **kwargs):
item = to_snake_case(each)
if item in document._fields_ordered + tuple(filter_args):
queried_fields.append(item)
return document.objects().no_dereference().only(*list(
set(list(_type._meta.required_fields) + queried_fields))).get(
pk=de_referenced["_ref"].id)
return document()
return await sync_to_async(document.objects().no_dereference().only(*list(
set(list(_type._meta.required_fields) + queried_fields))).get, thread_sensitive=False,
executor=ThreadPoolExecutor())(pk=de_referenced["_ref"].id)
return await sync_to_async(document, thread_sensitive=False,
executor=ThreadPoolExecutor())()
return None

def lazy_reference_resolver(root, *args, **kwargs):
async def lazy_reference_resolver(root, *args, **kwargs):
document = getattr(root, field.name or field.db_name)
if document:
queried_fields = list()
Expand All @@ -288,10 +290,11 @@ def lazy_reference_resolver(root, *args, **kwargs):
if item in document.document_type._fields_ordered + tuple(filter_args):
queried_fields.append(item)
_type = registry.get_type_for_model(document.document_type)
return document.document_type.objects().no_dereference().only(
*(set((list(_type._meta.required_fields) + queried_fields)))).get(
pk=document.pk)
return document.document_type()
return await sync_to_async(document.document_type.objects().no_dereference().only(
*(set((list(_type._meta.required_fields) + queried_fields)))).get, thread_sensitive=False,
executor=ThreadPoolExecutor())(pk=document.pk)
return await sync_to_async(document.document_type, thread_sensitive=False,
executor=ThreadPoolExecutor())()
return None

if isinstance(field, mongoengine.GenericLazyReferenceField):
Expand Down Expand Up @@ -327,7 +330,7 @@ def lazy_reference_resolver(root, *args, **kwargs):
def convert_field_to_dynamic(field, registry=None):
model = field.document_type

def reference_resolver(root, *args, **kwargs):
async def reference_resolver(root, *args, **kwargs):
document = getattr(root, field.name or field.db_name)
if document:
queried_fields = list()
Expand All @@ -341,12 +344,12 @@ def reference_resolver(root, *args, **kwargs):
item = to_snake_case(each)
if item in field.document_type._fields_ordered + tuple(filter_args):
queried_fields.append(item)
return field.document_type.objects().no_dereference().only(
*(set(list(_type._meta.required_fields) + queried_fields))).get(
pk=document.id)
return await sync_to_async(field.document_type.objects().no_dereference().only(
*(set(list(_type._meta.required_fields) + queried_fields))).get, thread_sensitive=False,
executor=ThreadPoolExecutor())(pk=document.id)
return None

def cached_reference_resolver(root, *args, **kwargs):
async def cached_reference_resolver(root, *args, **kwargs):
if field:
queried_fields = list()
_type = registry.get_type_for_model(field.document_type)
Expand All @@ -359,9 +362,10 @@ def cached_reference_resolver(root, *args, **kwargs):
item = to_snake_case(each)
if item in field.document_type._fields_ordered + tuple(filter_args):
queried_fields.append(item)
return field.document_type.objects().no_dereference().only(
return await sync_to_async(field.document_type.objects().no_dereference().only(
*(set(
list(_type._meta.required_fields) + queried_fields))).get(
list(_type._meta.required_fields) + queried_fields))).get, thread_sensitive=False,
executor=ThreadPoolExecutor())(
pk=getattr(root, field.name or field.db_name))
return None

Expand Down Expand Up @@ -394,7 +398,7 @@ def dynamic_type():
def convert_lazy_field_to_dynamic(field, registry=None):
model = field.document_type

def lazy_resolver(root, *args, **kwargs):
async def lazy_resolver(root, *args, **kwargs):
document = getattr(root, field.name or field.db_name)
if document:
queried_fields = list()
Expand All @@ -408,9 +412,9 @@ def lazy_resolver(root, *args, **kwargs):
item = to_snake_case(each)
if item in document.document_type._fields_ordered + tuple(filter_args):
queried_fields.append(item)
return document.document_type.objects().no_dereference().only(
*(set((list(_type._meta.required_fields) + queried_fields)))).get(
pk=document.pk)
return await sync_to_async(document.document_type.objects().no_dereference().only(
*(set((list(_type._meta.required_fields) + queried_fields)))).get, thread_sensitive=False,
executor=ThreadPoolExecutor())(pk=document.pk)
return None

def dynamic_type():
Expand Down
52 changes: 26 additions & 26 deletions graphene_mongo/fields.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import absolute_import

import logging
from collections import OrderedDict
from functools import partial, reduce

Expand All @@ -22,6 +21,8 @@
from mongoengine.base import get_document
from promise import Promise
from pymongo.errors import OperationFailure
from asgiref.sync import sync_to_async
from concurrent.futures import ThreadPoolExecutor

from .advanced_types import (
FileFieldType,
Expand Down Expand Up @@ -314,7 +315,7 @@ def get_queryset(self, model, info, required_fields=None, skip=None, limit=None,
skip)
return model.objects(**args).no_dereference().only(*required_fields).order_by(self.order_by)

def default_resolver(self, _root, info, required_fields=None, resolved=None, **args):
async def default_resolver(self, _root, info, required_fields=None, resolved=None, **args):
if required_fields is None:
required_fields = list()
args = args or {}
Expand Down Expand Up @@ -357,7 +358,8 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a

if isinstance(items, QuerySet):
try:
count = items.count(with_limit_and_skip=True)
count = await sync_to_async(items.count, thread_sensitive=False,
executor=ThreadPoolExecutor())(with_limit_and_skip=True)
except OperationFailure:
count = len(items)
else:
Expand Down Expand Up @@ -400,12 +402,13 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
args_copy[key] = args_copy[key].value

if PYMONGO_VERSION >= (3, 7):
if hasattr(self.model, '_meta') and 'db_alias' in self.model._meta:
count = (mongoengine.get_db(self.model._meta['db_alias'])[self.model._get_collection_name()]).count_documents(args_copy)
else:
count = (mongoengine.get_db()[self.model._get_collection_name()]).count_documents(args_copy)
count = await sync_to_async(
(mongoengine.get_db()[self.model._get_collection_name()]).count_documents,
thread_sensitive=False,
executor=ThreadPoolExecutor())(args_copy)
else:
count = self.model.objects(args_copy).count()
count = await sync_to_async(self.model.objects(args_copy).count, thread_sensitive=False,
executor=ThreadPoolExecutor())()
if count != 0:
skip, limit, reverse = find_skip_and_limit(first=first, after=after, last=last, before=before,
count=count)
Expand Down Expand Up @@ -467,7 +470,7 @@ def default_resolver(self, _root, info, required_fields=None, resolved=None, **a
connection.list_length = list_length
return connection

def chained_resolver(self, resolver, is_partial, root, info, **args):
async def chained_resolver(self, resolver, is_partial, root, info, **args):

for key, value in dict(args).items():
if value is None:
Expand Down Expand Up @@ -511,13 +514,13 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
elif not isinstance(resolved[0], DBRef):
return resolved
else:
return self.default_resolver(root, info, required_fields, **args_copy)
return await self.default_resolver(root, info, required_fields, **args_copy)
elif isinstance(resolved, QuerySet):
args.update(resolved._query)
args_copy = args.copy()
for arg_name, arg in args.copy().items():
if "." in arg_name or arg_name not in self.model._fields_ordered \
+ ('first', 'last', 'before', 'after') + tuple(self.filter_args.keys()):
if "." in arg_name or arg_name not in self.model._fields_ordered + (
'first', 'last', 'before', 'after') + tuple(self.filter_args.keys()):
args_copy.pop(arg_name)
if arg_name == '_id' and isinstance(arg, dict):
operation = list(arg.keys())[0]
Expand All @@ -537,38 +540,35 @@ def chained_resolver(self, resolver, is_partial, root, info, **args):
operation = list(arg.keys())[0]
args_copy[arg_name + operation.replace('$', '__')] = arg[operation]
del args_copy[arg_name]
return self.default_resolver(root, info, required_fields, resolved=resolved, **args_copy)

return await self.default_resolver(root, info, required_fields, resolved=resolved, **args_copy)
elif isinstance(resolved, Promise):
return resolved.value
else:
return resolved
return await resolved

return self.default_resolver(root, info, required_fields, **args)
return await self.default_resolver(root, info, required_fields, **args)

@classmethod
def connection_resolver(cls, resolver, connection_type, root, info, **args):
async def connection_resolver(cls, resolver, connection_type, root, info, **args):
if root:
for key, value in root.__dict__.items():
if value:
try:
setattr(root, key, from_global_id(value)[1])
except Exception as error:
logging.error("Exception Occurred: ", exc_info=error)
iterable = resolver(root, info, **args)

except Exception:
pass
iterable = await resolver(root, info, **args)
if isinstance(connection_type, graphene.NonNull):
connection_type = connection_type.of_type

on_resolve = partial(cls.resolve_connection, connection_type, args)

if Promise.is_thenable(iterable):
return Promise.resolve(iterable).then(on_resolve)

return on_resolve(iterable)
return await sync_to_async(cls.resolve_connection, thread_sensitive=False,
executor=ThreadPoolExecutor())(connection_type, args, iterable)

def get_resolver(self, parent_resolver):
super_resolver = self.resolver or parent_resolver
resolver = partial(
self.chained_resolver, super_resolver, isinstance(super_resolver, partial)
)

return partial(self.connection_resolver, resolver, self.type)
4 changes: 2 additions & 2 deletions graphene_mongo/tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def test_should_boolean_convert_boolean():
assert_conversion(mongoengine.BooleanField, graphene.Boolean)


def test_should_decimal_convert_float():
assert_conversion(mongoengine.DecimalField, graphene.Float)
def test_should_decimal_convert_decimal():
assert_conversion(mongoengine.DecimalField, graphene.Decimal)


def test_should_float_convert_float():
Expand Down
8 changes: 4 additions & 4 deletions graphene_mongo/tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,16 @@ def test_field_args_with_unconverted_field():
assert set(field.field_args.keys()) == set(field_args)


def test_default_resolver_with_colliding_objects_field():
async def test_default_resolver_with_colliding_objects_field():
field = MongoengineConnectionField(nodes.ErroneousModelNode)

connection = field.default_resolver(None, {})
connection = await field.default_resolver(None, {})
assert 0 == len(connection.iterable)


def test_default_resolver_connection_list_length(fixtures):
async def test_default_resolver_connection_list_length(fixtures):
field = MongoengineConnectionField(nodes.ArticleNode)

connection = field.default_resolver(None, {}, **{"first": 1})
connection = await field.default_resolver(None, {}, **{"first": 1})
assert hasattr(connection, "list_length")
assert connection.list_length == 1
12 changes: 6 additions & 6 deletions graphene_mongo/tests/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
from .types import ArticleInput, EditorInput


def test_should_create(fixtures):
async def test_should_create(fixtures):
class CreateArticle(graphene.Mutation):
class Arguments:
article = ArticleInput(required=True)

article = graphene.Field(ArticleNode)

def mutate(self, info, article):
async def mutate(self, info, article):
article = Article(**article)
article.save()

Expand All @@ -39,20 +39,20 @@ class Mutation(graphene.ObjectType):
"""
expected = {"createArticle": {"article": {"headline": "My Article"}}}
schema = graphene.Schema(query=Query, mutation=Mutation)
result = schema.execute(query)
result = await schema.execute_async(query)
assert not result.errors
assert result.data == expected


def test_should_update(fixtures):
async def test_should_update(fixtures):
class UpdateEditor(graphene.Mutation):
class Arguments:
id = graphene.ID(required=True)
editor = EditorInput(required=True)

editor = graphene.Field(EditorNode)

def mutate(self, info, id, editor):
async def mutate(self, info, id, editor):
editor_to_update = Editor.objects.get(id=id)
for key, value in editor.items():
if value:
Expand Down Expand Up @@ -85,7 +85,7 @@ class Mutation(graphene.ObjectType):
"""
expected = {"updateEditor": {"editor": {"firstName": "Penny", "lastName": "Lane"}}}
schema = graphene.Schema(query=Query, mutation=Mutation)
result = schema.execute(query)
result = await schema.execute_async(query)
# print(result.data)
assert not result.errors
assert result.data == expected
Loading

0 comments on commit 1f0250b

Please sign in to comment.