refact: add typings
mak626 committed Nov 25, 2023
1 parent 11d0c8b commit c87361c
Showing 2 changed files with 24 additions and 118 deletions.
graphene_mongo/fields.py (20 changes: 10 additions & 10 deletions)
@@ -8,6 +8,7 @@
import bson
import graphene
import mongoengine
import pymongo
from bson import DBRef, ObjectId
from graphene import Context
from graphene.relay import ConnectionField
@@ -17,29 +18,28 @@
from graphene.types.utils import get_type
from graphene.utils.str_converters import to_snake_case
from graphql import GraphQLResolveInfo
from graphql_relay import from_global_id, cursor_to_offset
from graphql_relay import cursor_to_offset, from_global_id
from mongoengine import QuerySet
from mongoengine.base import get_document
from promise import Promise
from pymongo.errors import OperationFailure

from .advanced_types import (
FileFieldType,
PointFieldType,
MultiPolygonFieldType,
PolygonFieldType,
PointFieldInputType,
PointFieldType,
PolygonFieldType,
)
from .converter import convert_mongoengine_field, MongoEngineConversionError
from .converter import MongoEngineConversionError, convert_mongoengine_field
from .registry import get_global_registry
from .utils import (
ExecutorEnum,
connection_from_iterables,
find_skip_and_limit,
get_model_reference_fields,
get_query_fields,
find_skip_and_limit,
connection_from_iterables,
ExecutorEnum,
)
import pymongo

PYMONGO_VERSION = tuple(pymongo.version_tuple[:2])

@@ -55,7 +55,7 @@ def __init__(self, type, *args, **kwargs):
super(MongoengineConnectionField, self).__init__(type, *args, **kwargs)

@property
def executor(self):
def executor(self) -> ExecutorEnum:
return ExecutorEnum.SYNC

@property
@@ -277,7 +277,7 @@ def fields(self):

def get_queryset(
self, model, info, required_fields=None, skip=None, limit=None, reversed=False, **args
):
) -> QuerySet:
if required_fields is None:
required_fields = list()

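
The fields.py side of the commit is limited to import reordering and two return-type annotations. A condensed sketch of how the annotated members read once the change is applied (class body trimmed; everything not shown is assumed unchanged):

```python
# Condensed sketch of the two annotations added in fields.py; the rest of
# the class is elided and unchanged by this commit.
from graphene.relay import ConnectionField
from mongoengine import QuerySet

from graphene_mongo.utils import ExecutorEnum


class MongoengineConnectionField(ConnectionField):
    @property
    def executor(self) -> ExecutorEnum:
        # The synchronous field always reports the SYNC executor.
        return ExecutorEnum.SYNC

    def get_queryset(
        self, model, info, required_fields=None, skip=None, limit=None, reversed=False, **args
    ) -> QuerySet:
        ...  # filtering and pagination logic, unchanged by this commit
```
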
graphene_mongo/fields_async.py (122 changes: 14 additions & 108 deletions)
@@ -1,32 +1,33 @@
from __future__ import absolute_import

from functools import partial
from typing import Coroutine
from itertools import filterfalse
from typing import Coroutine

import bson
import graphene
import mongoengine
import pymongo
from bson import DBRef, ObjectId
from graphene import Context
from graphene.relay import ConnectionField
from graphene.utils.str_converters import to_snake_case
from graphql import GraphQLResolveInfo
from graphql_relay import from_global_id, cursor_to_offset
from graphql_relay import cursor_to_offset, from_global_id
from mongoengine import QuerySet
from mongoengine.base import get_document
from promise import Promise
from pymongo.errors import OperationFailure
from .registry import get_global_async_registry

from . import MongoengineConnectionField
from .registry import get_global_async_registry
from .utils import (
get_query_fields,
find_skip_and_limit,
connection_from_iterables,
ExecutorEnum,
connection_from_iterables,
find_skip_and_limit,
get_query_fields,
sync_to_async,
get_model_reference_fields,
)
import pymongo

PYMONGO_VERSION = tuple(pymongo.version_tuple[:2])

@@ -60,101 +61,6 @@ def fields(self):
def registry(self):
return getattr(self.node_type._meta, "registry", get_global_async_registry())

async def get_queryset(
self, model, info, required_fields=None, skip=None, limit=None, reversed=False, **args
):
if required_fields is None:
required_fields = list()

if args:
reference_fields = get_model_reference_fields(self.model)
hydrated_references = {}
for arg_name, arg in args.copy().items():
if arg_name in reference_fields and not isinstance(
arg, mongoengine.base.metaclasses.TopLevelDocumentMetaclass
):
try:
reference_obj = reference_fields[arg_name].document_type(
pk=from_global_id(arg)[1]
)
except TypeError:
reference_obj = reference_fields[arg_name].document_type(pk=arg)
hydrated_references[arg_name] = reference_obj
elif arg_name in self.model._fields_ordered and isinstance(
getattr(self.model, arg_name), mongoengine.fields.GenericReferenceField
):
try:
reference_obj = get_document(
self.registry._registry_string_map[from_global_id(arg)[0]]
)(pk=from_global_id(arg)[1])
except TypeError:
reference_obj = get_document(arg["_cls"])(pk=arg["_ref"].id)
hydrated_references[arg_name] = reference_obj
elif "__near" in arg_name and isinstance(
getattr(self.model, arg_name.split("__")[0]), mongoengine.fields.PointField
):
location = args.pop(arg_name, None)
hydrated_references[arg_name] = location["coordinates"]
if (arg_name.split("__")[0] + "__max_distance") not in args:
hydrated_references[arg_name.split("__")[0] + "__max_distance"] = 10000
elif arg_name == "id":
hydrated_references["id"] = from_global_id(args.pop("id", None))[1]
args.update(hydrated_references)

if self._get_queryset:
queryset_or_filters = self._get_queryset(model, info, **args)
if isinstance(queryset_or_filters, mongoengine.QuerySet):
return queryset_or_filters
else:
args.update(queryset_or_filters)
if limit is not None:
if reversed:
if self.order_by:
order_by = self.order_by + ",-pk"
else:
order_by = "-pk"
return await sync_to_async(
model.objects(**args)
.no_dereference()
.only(*required_fields)
.order_by(order_by)
.skip(skip if skip else 0)
.limit
)(limit)
else:
return await sync_to_async(
model.objects(**args)
.no_dereference()
.only(*required_fields)
.order_by(self.order_by)
.skip(skip if skip else 0)
.limit
)(limit)
elif skip is not None:
if reversed:
if self.order_by:
order_by = self.order_by + ",-pk"
else:
order_by = "-pk"
return await sync_to_async(
model.objects(**args)
.no_dereference()
.only(*required_fields)
.order_by(order_by)
.skip
)(skip)
else:
return await sync_to_async(
model.objects(**args)
.no_dereference()
.only(*required_fields)
.order_by(self.order_by)
.skip
)(skip)
return await sync_to_async(
model.objects(**args).no_dereference().only(*required_fields).order_by
)(self.order_by)

async def default_resolver(self, _root, info, required_fields=None, resolved=None, **args):
if required_fields is None:
required_fields = list()
@@ -284,15 +190,15 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=None, **args):
skip, limit, reverse = find_skip_and_limit(
first=first, after=after, last=last, before=before, count=count
)
iterables = await self.get_queryset(
iterables = self.get_queryset(
self.model, info, required_fields, skip, limit, reverse, **args
)
iterables = await sync_to_async(list)(iterables)
list_length = len(iterables)
if isinstance(info, GraphQLResolveInfo):
if not info.context:
info = info._replace(context=Context())
info.context.queryset = await self.get_queryset(
info.context.queryset = self.get_queryset(
self.model, info, required_fields, **args
)

@@ -308,13 +214,13 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=None, **args):
args["pk__in"] = args["pk__in"][skip : skip + limit]
elif skip:
args["pk__in"] = args["pk__in"][skip:]
iterables = await self.get_queryset(self.model, info, required_fields, **args)
iterables = self.get_queryset(self.model, info, required_fields, **args)
iterables = await sync_to_async(list)(iterables)
list_length = len(iterables)
if isinstance(info, GraphQLResolveInfo):
if not info.context:
info = info._replace(context=Context())
info.context.queryset = await self.get_queryset(
info.context.queryset = self.get_queryset(
self.model, info, required_fields, **args
)

@@ -410,7 +316,7 @@ def filter_connection(x):
if isinstance(info, GraphQLResolveInfo):
if not info.context:
info = info._replace(context=Context())
info.context.queryset = await self.get_queryset(
info.context.queryset = self.get_queryset(
self.model, info, required_fields, **args_copy
)

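
With the async override of get_queryset deleted, the async connection field falls back to the synchronous implementation inherited via MongoengineConnectionField (imported above), and only the cursor evaluation is awaited. A minimal sketch of that call pattern using the project's sync_to_async helper; the function and variable names below are illustrative, not part of the commit:

```python
# Sketch of the resolver pattern after this commit: the inherited, synchronous
# get_queryset builds a lazy QuerySet, and only list(), the call that actually
# hits MongoDB, is pushed off the event loop.
from graphene_mongo.utils import sync_to_async


async def fetch_page(connection_field, model, info, required_fields, skip, limit, reverse, **args):
    # No await here: get_queryset is the plain method from
    # MongoengineConnectionField and returns without touching the database.
    queryset = connection_field.get_queryset(
        model, info, required_fields, skip, limit, reverse, **args
    )
    # Materialising the cursor is the blocking step, so only that is wrapped.
    return await sync_to_async(list)(queryset)
```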
