Skip to content

Commit

Permalink
Support Async, Test Case Added
Browse files Browse the repository at this point in the history
  • Loading branch information
arunsureshkumar committed Apr 11, 2023
1 parent 7c99bc1 commit 2fe5008
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 104 deletions.
52 changes: 31 additions & 21 deletions graphene_mongo/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,10 @@ def reference_resolver(root, *args, **kwargs):
querying_union_types.remove('__typename')
to_resolve_models = list()
for each in querying_union_types:
to_resolve_models.append(registry._registry_string_map[each])
if executor == ExecutorEnum.SYNC:
to_resolve_models.append(registry._registry_string_map[each])
else:
to_resolve_models.append(registry._registry_async_string_map[each])
to_resolve_object_ids = list()
for each in to_resolve:
if isinstance(each, LazyReference):
Expand Down Expand Up @@ -219,7 +222,10 @@ async def reference_resolver_async(root, *args, **kwargs):
querying_union_types.remove('__typename')
to_resolve_models = list()
for each in querying_union_types:
to_resolve_models.append(registry._registry_string_map[each])
if executor == ExecutorEnum.SYNC:
to_resolve_models.append(registry._registry_string_map[each])
else:
to_resolve_models.append(registry._registry_async_string_map[each])
to_resolve_object_ids = list()
for each in to_resolve:
if isinstance(each, LazyReference):
Expand Down Expand Up @@ -297,7 +303,7 @@ def convert_field_to_union(field, registry=None, executor: ExecutorEnum = Execut
elif isinstance(field, mongoengine.GenericEmbeddedDocumentField):
_field = mongoengine.EmbeddedDocumentField(choice)

_field = convert_mongoengine_field(_field, registry)
_field = convert_mongoengine_field(_field, registry, executor=executor)
_type = _field.get_type()
if _type:
_types.append(_type.type)
Expand All @@ -311,7 +317,7 @@ def convert_field_to_union(field, registry=None, executor: ExecutorEnum = Execut
name = to_camel_case("{}_{}".format(
field._owner_document.__name__,
field.db_field
)) + "UnionType"
)) + "UnionType" if ExecutorEnum.SYNC else "AsyncUnionType"
Meta = type("Meta", (object,), {"types": tuple(_types)})
_union = type(name, (graphene.Union,), {"Meta": Meta})

Expand All @@ -320,7 +326,7 @@ def reference_resolver(root, *args, **kwargs):
if de_referenced:
document = get_document(de_referenced["_cls"])
document_field = mongoengine.ReferenceField(document)
document_field = convert_mongoengine_field(document_field, registry)
document_field = convert_mongoengine_field(document_field, registry, executor=executor)
_type = document_field.get_type().type
filter_args = list()
if _type._meta.filter_fields:
Expand All @@ -344,7 +350,7 @@ def lazy_reference_resolver(root, *args, **kwargs):
document = getattr(root, field.name or field.db_name)
if document:
queried_fields = list()
document_field_type = registry.get_type_for_model(document.document_type)
document_field_type = registry.get_type_for_model(document.document_type, executor=executor)
querying_types = list(get_query_fields(args[0]).keys())
filter_args = list()
if document_field_type._meta.filter_fields:
Expand All @@ -356,7 +362,7 @@ def lazy_reference_resolver(root, *args, **kwargs):
item = to_snake_case(each)
if item in document.document_type._fields_ordered + tuple(filter_args):
queried_fields.append(item)
_type = registry.get_type_for_model(document.document_type)
_type = registry.get_type_for_model(document.document_type, executor=executor)
return document.document_type.objects().no_dereference().only(
*(set((list(_type._meta.required_fields) + queried_fields)))).get(
pk=document.pk)
Expand Down Expand Up @@ -393,7 +399,7 @@ async def lazy_reference_resolver_async(root, *args, **kwargs):
document = getattr(root, field.name or field.db_name)
if document:
queried_fields = list()
document_field_type = registry.get_type_for_model(document.document_type)
document_field_type = registry.get_type_for_model(document.document_type, executor=executor)
querying_types = list(get_query_fields(args[0]).keys())
filter_args = list()
if document_field_type._meta.filter_fields:
Expand All @@ -405,7 +411,7 @@ async def lazy_reference_resolver_async(root, *args, **kwargs):
item = to_snake_case(each)
if item in document.document_type._fields_ordered + tuple(filter_args):
queried_fields.append(item)
_type = registry.get_type_for_model(document.document_type)
_type = registry.get_type_for_model(document.document_type, executor=executor)
return await sync_to_async(document.document_type.objects().no_dereference().only(
*(set((list(_type._meta.required_fields) + queried_fields)))).get, thread_sensitive=False,
executor=ThreadPoolExecutor())(pk=document.pk)
Expand All @@ -418,7 +424,8 @@ async def lazy_reference_resolver_async(root, *args, **kwargs):
required = False
if field.db_field is not None:
required = field.required
resolver_function = getattr(registry.get_type_for_model(field.owner_document), "resolve_" + field.db_field,
resolver_function = getattr(registry.get_type_for_model(field.owner_document, executor=executor),
"resolve_" + field.db_field,
None)
if resolver_function and callable(resolver_function):
field_resolver = resolver_function
Expand All @@ -431,7 +438,8 @@ async def lazy_reference_resolver_async(root, *args, **kwargs):
required = False
if field.db_field is not None:
required = field.required
resolver_function = getattr(registry.get_type_for_model(field.owner_document), "resolve_" + field.db_field,
resolver_function = getattr(registry.get_type_for_model(field.owner_document, executor=executor),
"resolve_" + field.db_field,
None)
if resolver_function and callable(resolver_function):
field_resolver = resolver_function
Expand All @@ -452,7 +460,7 @@ def reference_resolver(root, *args, **kwargs):
document = getattr(root, field.name or field.db_name)
if document:
queried_fields = list()
_type = registry.get_type_for_model(field.document_type)
_type = registry.get_type_for_model(field.document_type, executor=executor)
filter_args = list()
if _type._meta.filter_fields:
for key, values in _type._meta.filter_fields.items():
Expand All @@ -470,7 +478,7 @@ def reference_resolver(root, *args, **kwargs):
def cached_reference_resolver(root, *args, **kwargs):
if field:
queried_fields = list()
_type = registry.get_type_for_model(field.document_type)
_type = registry.get_type_for_model(field.document_type, executor=executor)
filter_args = list()
if _type._meta.filter_fields:
for key, values in _type._meta.filter_fields.items():
Expand All @@ -490,7 +498,7 @@ async def reference_resolver_async(root, *args, **kwargs):
document = getattr(root, field.name or field.db_name)
if document:
queried_fields = list()
_type = registry.get_type_for_model(field.document_type)
_type = registry.get_type_for_model(field.document_type, executor=executor)
filter_args = list()
if _type._meta.filter_fields:
for key, values in _type._meta.filter_fields.items():
Expand All @@ -508,7 +516,7 @@ async def reference_resolver_async(root, *args, **kwargs):
async def cached_reference_resolver_async(root, *args, **kwargs):
if field:
queried_fields = list()
_type = registry.get_type_for_model(field.document_type)
_type = registry.get_type_for_model(field.document_type, executor=executor)
filter_args = list()
if _type._meta.filter_fields:
for key, values in _type._meta.filter_fields.items():
Expand All @@ -526,7 +534,7 @@ async def cached_reference_resolver_async(root, *args, **kwargs):
return None

def dynamic_type():
_type = registry.get_type_for_model(model)
_type = registry.get_type_for_model(model, executor=executor)
if not _type:
return None
if isinstance(field, mongoengine.EmbeddedDocumentField):
Expand All @@ -536,7 +544,8 @@ def dynamic_type():
required = False
if field.db_field is not None:
required = field.required
resolver_function = getattr(registry.get_type_for_model(field.owner_document), "resolve_" + field.db_field,
resolver_function = getattr(registry.get_type_for_model(field.owner_document, executor=executor),
"resolve_" + field.db_field,
None)
if resolver_function and callable(resolver_function):
field_resolver = resolver_function
Expand All @@ -560,7 +569,7 @@ def lazy_resolver(root, *args, **kwargs):
document = getattr(root, field.name or field.db_name)
if document:
queried_fields = list()
_type = registry.get_type_for_model(document.document_type)
_type = registry.get_type_for_model(document.document_type, executor=executor)
filter_args = list()
if _type._meta.filter_fields:
for key, values in _type._meta.filter_fields.items():
Expand All @@ -579,7 +588,7 @@ async def lazy_resolver_async(root, *args, **kwargs):
document = getattr(root, field.name or field.db_name)
if document:
queried_fields = list()
_type = registry.get_type_for_model(document.document_type)
_type = registry.get_type_for_model(document.document_type, executor=executor)
filter_args = list()
if _type._meta.filter_fields:
for key, values in _type._meta.filter_fields.items():
Expand All @@ -595,14 +604,15 @@ async def lazy_resolver_async(root, *args, **kwargs):
return None

def dynamic_type():
_type = registry.get_type_for_model(model)
_type = registry.get_type_for_model(model, executor=executor)
if not _type:
return None
field_resolver = None
required = False
if field.db_field is not None:
required = field.required
resolver_function = getattr(registry.get_type_for_model(field.owner_document), "resolve_" + field.db_field,
resolver_function = getattr(registry.get_type_for_model(field.owner_document, executor=executor),
"resolve_" + field.db_field,
None)
if resolver_function and callable(resolver_function):
field_resolver = resolver_function
Expand Down
6 changes: 0 additions & 6 deletions graphene_mongo/fields_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,6 @@

class AsyncMongoengineConnectionField(MongoengineConnectionField):
def __init__(self, type, *args, **kwargs):
get_queryset = kwargs.pop("get_queryset", None)
if get_queryset:
assert callable(
get_queryset
), "Attribute `get_queryset` on {} must be callable.".format(self)
self._get_queryset = get_queryset
super(AsyncMongoengineConnectionField, self).__init__(type, *args, **kwargs)

@property
Expand Down
19 changes: 15 additions & 4 deletions graphene_mongo/registry.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from graphene import Enum

from graphene_mongo.utils import ExecutorEnum


class Registry(object):
def __init__(self):
self._registry = {}
self._registry_async = {}
self._registry_string_map = {}
self._registry_async_string_map = {}
self._registry_enum = {}

def register(self, cls):
Expand All @@ -21,8 +25,12 @@ def register(self, cls):
cls.__name__
)
assert cls._meta.registry == self, "Registry for a Model have to match."
self._registry[cls._meta.model] = cls
self._registry_string_map[cls.__name__] = cls._meta.model.__name__
if issubclass(cls, GrapheneMongoengineObjectTypes):
self._registry[cls._meta.model] = cls
self._registry_string_map[cls.__name__] = cls._meta.model.__name__
else:
self._registry_async[cls._meta.model] = cls
self._registry_async_string_map[cls.__name__] = cls._meta.model.__name__

# Rescan all fields
for model, cls in self._registry.items():
Expand All @@ -40,8 +48,11 @@ def register_enum(self, cls):
cls.__name__ = name
self._registry_enum[cls] = Enum.from_enum(cls)

def get_type_for_model(self, model):
return self._registry.get(model)
def get_type_for_model(self, model, executor: ExecutorEnum = ExecutorEnum.SYNC):
if executor == ExecutorEnum.SYNC:
return self._registry.get(model)
else:
return self._registry_async.get(model)

def check_enum_already_exist(self, cls):
return cls in self._registry_enum
Expand Down
19 changes: 10 additions & 9 deletions graphene_mongo/tests/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from . import models
from . import types # noqa: F401
from .models import ProfessorMetadata
from ..types_async import AsyncMongoengineObjectType
from ..types import MongoengineObjectType

Expand Down Expand Up @@ -48,12 +49,6 @@ class Meta:
interfaces = (Node,)


class ReporterNodeAsync(AsyncMongoengineObjectType):
class Meta:
model = models.Reporter
interfaces = (Node,)


class ParentNode(MongoengineObjectType):
class Meta:
model = models.Parent
Expand All @@ -72,16 +67,22 @@ class Meta:
interfaces = (Node,)


class ChildRegisteredAfterNode(MongoengineObjectType):
class Meta:
model = models.ChildRegisteredAfter
interfaces = (Node,)


class ParentWithRelationshipNode(MongoengineObjectType):
class Meta:
model = models.ParentWithRelationship
interfaces = (Node,)


class ChildRegisteredAfterNode(MongoengineObjectType):
class ProfessorMetadataNode(MongoengineObjectType):
class Meta:
model = models.ChildRegisteredAfter
interfaces = (Node,)
model = ProfessorMetadata
interfaces = (graphene.Node,)


class ProfessorVectorNode(MongoengineObjectType):
Expand Down
Loading

0 comments on commit 2fe5008

Please sign in to comment.