diff --git a/graphene_mongo/fields.py b/graphene_mongo/fields.py index 59903f0..1e6bdf0 100644 --- a/graphene_mongo/fields.py +++ b/graphene_mongo/fields.py @@ -8,6 +8,7 @@ import bson import graphene import mongoengine +import pymongo from bson import DBRef, ObjectId from graphene import Context from graphene.relay import ConnectionField @@ -17,7 +18,7 @@ from graphene.types.utils import get_type from graphene.utils.str_converters import to_snake_case from graphql import GraphQLResolveInfo -from graphql_relay import from_global_id, cursor_to_offset +from graphql_relay import cursor_to_offset, from_global_id from mongoengine import QuerySet from mongoengine.base import get_document from promise import Promise @@ -25,21 +26,20 @@ from .advanced_types import ( FileFieldType, - PointFieldType, MultiPolygonFieldType, - PolygonFieldType, PointFieldInputType, + PointFieldType, + PolygonFieldType, ) -from .converter import convert_mongoengine_field, MongoEngineConversionError +from .converter import MongoEngineConversionError, convert_mongoengine_field from .registry import get_global_registry from .utils import ( + ExecutorEnum, + connection_from_iterables, + find_skip_and_limit, get_model_reference_fields, get_query_fields, - find_skip_and_limit, - connection_from_iterables, - ExecutorEnum, ) -import pymongo PYMONGO_VERSION = tuple(pymongo.version_tuple[:2]) @@ -55,7 +55,7 @@ def __init__(self, type, *args, **kwargs): super(MongoengineConnectionField, self).__init__(type, *args, **kwargs) @property - def executor(self): + def executor(self) -> ExecutorEnum: return ExecutorEnum.SYNC @property @@ -277,7 +277,7 @@ def fields(self): def get_queryset( self, model, info, required_fields=None, skip=None, limit=None, reversed=False, **args - ): + ) -> QuerySet: if required_fields is None: required_fields = list() diff --git a/graphene_mongo/fields_async.py b/graphene_mongo/fields_async.py index b930bcd..314623e 100644 --- a/graphene_mongo/fields_async.py +++ b/graphene_mongo/fields_async.py @@ -1,32 +1,33 @@ from __future__ import absolute_import + from functools import partial -from typing import Coroutine from itertools import filterfalse +from typing import Coroutine import bson import graphene import mongoengine +import pymongo from bson import DBRef, ObjectId from graphene import Context from graphene.relay import ConnectionField from graphene.utils.str_converters import to_snake_case from graphql import GraphQLResolveInfo -from graphql_relay import from_global_id, cursor_to_offset +from graphql_relay import cursor_to_offset, from_global_id from mongoengine import QuerySet from mongoengine.base import get_document from promise import Promise from pymongo.errors import OperationFailure -from .registry import get_global_async_registry + from . import MongoengineConnectionField +from .registry import get_global_async_registry from .utils import ( - get_query_fields, - find_skip_and_limit, - connection_from_iterables, ExecutorEnum, + connection_from_iterables, + find_skip_and_limit, + get_query_fields, sync_to_async, - get_model_reference_fields, ) -import pymongo PYMONGO_VERSION = tuple(pymongo.version_tuple[:2]) @@ -60,101 +61,6 @@ def fields(self): def registry(self): return getattr(self.node_type._meta, "registry", get_global_async_registry()) - async def get_queryset( - self, model, info, required_fields=None, skip=None, limit=None, reversed=False, **args - ): - if required_fields is None: - required_fields = list() - - if args: - reference_fields = get_model_reference_fields(self.model) - hydrated_references = {} - for arg_name, arg in args.copy().items(): - if arg_name in reference_fields and not isinstance( - arg, mongoengine.base.metaclasses.TopLevelDocumentMetaclass - ): - try: - reference_obj = reference_fields[arg_name].document_type( - pk=from_global_id(arg)[1] - ) - except TypeError: - reference_obj = reference_fields[arg_name].document_type(pk=arg) - hydrated_references[arg_name] = reference_obj - elif arg_name in self.model._fields_ordered and isinstance( - getattr(self.model, arg_name), mongoengine.fields.GenericReferenceField - ): - try: - reference_obj = get_document( - self.registry._registry_string_map[from_global_id(arg)[0]] - )(pk=from_global_id(arg)[1]) - except TypeError: - reference_obj = get_document(arg["_cls"])(pk=arg["_ref"].id) - hydrated_references[arg_name] = reference_obj - elif "__near" in arg_name and isinstance( - getattr(self.model, arg_name.split("__")[0]), mongoengine.fields.PointField - ): - location = args.pop(arg_name, None) - hydrated_references[arg_name] = location["coordinates"] - if (arg_name.split("__")[0] + "__max_distance") not in args: - hydrated_references[arg_name.split("__")[0] + "__max_distance"] = 10000 - elif arg_name == "id": - hydrated_references["id"] = from_global_id(args.pop("id", None))[1] - args.update(hydrated_references) - - if self._get_queryset: - queryset_or_filters = self._get_queryset(model, info, **args) - if isinstance(queryset_or_filters, mongoengine.QuerySet): - return queryset_or_filters - else: - args.update(queryset_or_filters) - if limit is not None: - if reversed: - if self.order_by: - order_by = self.order_by + ",-pk" - else: - order_by = "-pk" - return await sync_to_async( - model.objects(**args) - .no_dereference() - .only(*required_fields) - .order_by(order_by) - .skip(skip if skip else 0) - .limit - )(limit) - else: - return await sync_to_async( - model.objects(**args) - .no_dereference() - .only(*required_fields) - .order_by(self.order_by) - .skip(skip if skip else 0) - .limit - )(limit) - elif skip is not None: - if reversed: - if self.order_by: - order_by = self.order_by + ",-pk" - else: - order_by = "-pk" - return await sync_to_async( - model.objects(**args) - .no_dereference() - .only(*required_fields) - .order_by(order_by) - .skip - )(skip) - else: - return await sync_to_async( - model.objects(**args) - .no_dereference() - .only(*required_fields) - .order_by(self.order_by) - .skip - )(skip) - return await sync_to_async( - model.objects(**args).no_dereference().only(*required_fields).order_by - )(self.order_by) - async def default_resolver(self, _root, info, required_fields=None, resolved=None, **args): if required_fields is None: required_fields = list() @@ -284,7 +190,7 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non skip, limit, reverse = find_skip_and_limit( first=first, after=after, last=last, before=before, count=count ) - iterables = await self.get_queryset( + iterables = self.get_queryset( self.model, info, required_fields, skip, limit, reverse, **args ) iterables = await sync_to_async(list)(iterables) @@ -292,7 +198,7 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non if isinstance(info, GraphQLResolveInfo): if not info.context: info = info._replace(context=Context()) - info.context.queryset = await self.get_queryset( + info.context.queryset = self.get_queryset( self.model, info, required_fields, **args ) @@ -308,13 +214,13 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non args["pk__in"] = args["pk__in"][skip : skip + limit] elif skip: args["pk__in"] = args["pk__in"][skip:] - iterables = await self.get_queryset(self.model, info, required_fields, **args) + iterables = self.get_queryset(self.model, info, required_fields, **args) iterables = await sync_to_async(list)(iterables) list_length = len(iterables) if isinstance(info, GraphQLResolveInfo): if not info.context: info = info._replace(context=Context()) - info.context.queryset = await self.get_queryset( + info.context.queryset = self.get_queryset( self.model, info, required_fields, **args ) @@ -410,7 +316,7 @@ def filter_connection(x): if isinstance(info, GraphQLResolveInfo): if not info.context: info = info._replace(context=Context()) - info.context.queryset = await self.get_queryset( + info.context.queryset = self.get_queryset( self.model, info, required_fields, **args_copy )