From 09066a74524fccccf53557ca0a0d781aad0ed9e0 Mon Sep 17 00:00:00 2001 From: Arun Suresh Kumar Date: Wed, 12 Apr 2023 10:22:24 +0530 Subject: [PATCH] Blocking Threaded to Async --- graphene_mongo/fields_async.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/graphene_mongo/fields_async.py b/graphene_mongo/fields_async.py index 2203f939..2115da69 100644 --- a/graphene_mongo/fields_async.py +++ b/graphene_mongo/fields_async.py @@ -1,5 +1,5 @@ from __future__ import absolute_import - +from collections.abc import Iterable from functools import partial from typing import Coroutine @@ -107,7 +107,7 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non items = items[skip:skip + limit] elif skip: items = items[skip:] - iterables = items + iterables = await sync_to_async(list, thread_sensitive=False, executor=ThreadPoolExecutor())(items) list_length = len(iterables) elif callable(getattr(self.model, "objects", None)): @@ -145,6 +145,8 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non skip, limit, reverse = find_skip_and_limit(first=first, after=after, last=last, before=before, count=count) iterables = self.get_queryset(self.model, info, required_fields, skip, limit, reverse, **args) + iterables = await sync_to_async(list, thread_sensitive=False, executor=ThreadPoolExecutor())( + iterables) list_length = len(iterables) if isinstance(info, GraphQLResolveInfo): if not info.context: @@ -163,6 +165,8 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non elif skip: args["pk__in"] = args["pk__in"][skip:] iterables = self.get_queryset(self.model, info, required_fields, **args) + iterables = await sync_to_async(list, thread_sensitive=False, executor=ThreadPoolExecutor())( + iterables) list_length = len(iterables) if isinstance(info, GraphQLResolveInfo): if not info.context: @@ -183,12 +187,15 @@ async def default_resolver(self, _root, info, required_fields=None, resolved=Non elif skip: items = items[skip:] iterables = items + iterables = await sync_to_async(list, thread_sensitive=False, executor=ThreadPoolExecutor())( + iterables) list_length = len(iterables) has_next_page = True if (0 if limit is None else limit) + (0 if skip is None else skip) < count else False has_previous_page = True if skip else False if reverse: - iterables = list(iterables) + iterables = await sync_to_async(list, thread_sensitive=False, executor=ThreadPoolExecutor())( + iterables) iterables.reverse() skip = limit connection = connection_from_iterables(edges=iterables, start_offset=skip, @@ -296,8 +303,7 @@ async def connection_resolver(cls, resolver, connection_type, root, info, **args iterable = await resolver(root, info, **args) if isinstance(connection_type, graphene.NonNull): connection_type = connection_type.of_type + on_resolve = partial(cls.resolve_connection, connection_type, args) if Promise.is_thenable(iterable): - on_resolve = partial(cls.resolve_connection, connection_type, args) iterable = Promise.resolve(iterable).then(on_resolve).value - return await sync_to_async(cls.resolve_connection, thread_sensitive=False, - executor=ThreadPoolExecutor())(connection_type, args, iterable) + return on_resolve(iterable)