From ce2ad90e9ac615d8a495495edbb0e90d305a66fa Mon Sep 17 00:00:00 2001 From: James Socol Date: Mon, 24 Jul 2023 15:47:37 -0400 Subject: [PATCH] Add support for async functions to decorator Tries to add support for async functions to the decorator, but trips over not having a failing test to fix. --- .github/actions/test/action.yml | 22 ++++++++++++++++------ django_ratelimit/decorators.py | 27 ++++++++++++++++++++++++++- django_ratelimit/tests.py | 30 +++++++++++++++++++++++++++++- 3 files changed, 71 insertions(+), 8 deletions(-) diff --git a/.github/actions/test/action.yml b/.github/actions/test/action.yml index 73d6392..90197ac 100644 --- a/.github/actions/test/action.yml +++ b/.github/actions/test/action.yml @@ -14,13 +14,23 @@ runs: with: python-version: ${{ inputs.python-version }} - - name: Install dependencies + - name: Update pip shell: sh - run: | - python -m pip install --upgrade pip - if [[ ${{ inputs.django-version }} != 'main' ]]; then pip install --pre -q "Django>=${{ inputs.django-version }},<${{ inputs.django-version }}.99"; fi - if [[ ${{ inputs.django-version }} == 'main' ]]; then pip install https://github.com/django/django/archive/main.tar.gz; fi - pip install flake8 django-redis pymemcache + run: python -m pip install --upgrade pip + + - name: Install Django + shell: sh + run: python -m pip install "Django>=${{ inputs.django-version }},<${{ inputs.django-version }}.99" + if: ${{ inputs.django-version != 'main' }} + + - name: Install Django main + shell: sh + run: python -m pip install https://github.com/django/django/archive/main.tar.gz + if: ${{ inputs.django-version == 'main' }} + + - name: Install Django dependencies + shell: sh + run: pip install flake8 django-redis pymemcache - name: Test shell: sh diff --git a/django_ratelimit/decorators.py b/django_ratelimit/decorators.py index 40c9541..0d50cea 100644 --- a/django_ratelimit/decorators.py +++ b/django_ratelimit/decorators.py @@ -1,4 +1,10 @@ from functools import wraps +import django +if django.VERSION >= (4, 1): + from asgiref.sync import iscoroutinefunction +else: + def iscoroutinefunction(func): + return False from django.conf import settings from django.utils.module_loading import import_string @@ -13,6 +19,23 @@ def ratelimit(group=None, key=None, rate=None, method=ALL, block=True): def decorator(fn): + # if iscoroutinefunction(fn): + # @wraps(fn) + # async def _async_wrapped(request, *args, **kw): + # old_limited = getattr(request, 'limited', False) + # ratelimited = is_ratelimited( + # request=request, group=group, fn=fn, key=key, rate=rate, + # method=method, increment=True) + # request.limited = ratelimited or old_limited + # if ratelimited and block: + # cls = getattr( + # settings, 'RATELIMIT_EXCEPTION_CLASS', Ratelimited) + # if isinstance(cls, str): + # cls = import_string(cls) + # raise cls() + # return await fn(request, *args, **kw) + # return _async_wrapped + @wraps(fn) def _wrapped(request, *args, **kw): old_limited = getattr(request, 'limited', False) @@ -23,7 +46,9 @@ def _wrapped(request, *args, **kw): if ratelimited and block: cls = getattr( settings, 'RATELIMIT_EXCEPTION_CLASS', Ratelimited) - raise (import_string(cls) if isinstance(cls, str) else cls)() + if isinstance(cls, str): + cls = import_string(cls) + raise cls() return fn(request, *args, **kw) return _wrapped return decorator diff --git a/django_ratelimit/tests.py b/django_ratelimit/tests.py index a58c89e..639c9ff 100644 --- a/django_ratelimit/tests.py +++ b/django_ratelimit/tests.py @@ -1,3 +1,6 @@ +import asyncio + +import django from functools import partial from django.core.cache import cache, InvalidCacheBackendError @@ -12,7 +15,10 @@ from django_ratelimit.core import (get_usage, is_ratelimited, _split_rate, _get_ip) - +if django.VERSION >= (4, 1): + from asgiref.sync import iscoroutinefunction + from django.test import AsyncRequestFactory + arf = AsyncRequestFactory() rf = RequestFactory() @@ -412,6 +418,28 @@ def view(request): assert not view(req) +if django.VERSION >= (4, 1): + class AsyncTests(TestCase): + def setUp(self): + cache.clear() + + async def test_decorate_async_function(self): + @ratelimit(key='ip', rate='1/m', block=False) + async def view(request): + await asyncio.sleep(0) + return request.limited + + req1 = arf.get('/') + req1.META['REMOTE_ADDR'] = '1.2.3.4' + + req2 = arf.get('/') + req2.META['REMOTE_ADDR'] = '1.2.3.4' + + assert iscoroutinefunction(view) + assert await view(req1) is False + assert await view(req2) is True + + class FunctionsTests(TestCase): def setUp(self): cache.clear()