diff --git a/django_ratelimit/core.py b/django_ratelimit/core.py index 1270799..694125d 100644 --- a/django_ratelimit/core.py +++ b/django_ratelimit/core.py @@ -14,7 +14,7 @@ from django_ratelimit import ALL, UNSAFE -__all__ = ['is_ratelimited', 'get_usage'] +__all__ = ['is_ratelimited', 'ais_ratelimited', 'get_usage', 'aget_usage'] _PERIODS = { 's': 1, @@ -156,9 +156,30 @@ def is_ratelimited(request, group=None, fn=None, key=None, rate=None, return usage['should_limit'] +async def ais_ratelimited(request, group=None, fn=None, key=None, rate=None, + method=ALL, increment=False): + usage = await aget_usage(request, group, fn, key, rate, method, increment) + if usage is None: + return False + + return usage['should_limit'] + def get_usage(request, group=None, fn=None, key=None, rate=None, method=ALL, increment=False): + usage = _get_usage(request, group, fn, key, rate, method, increment) + if usage is not None: + return usage() + +async def aget_usage(request, group=None, fn=None, key=None, rate=None, method=ALL, + increment=False): + usage = _get_usage(request, group, fn, key, rate, method, increment, is_async=True) + if usage is not None: + return await usage() + + +def _get_usage(request, group=None, fn=None, key=None, rate=None, method=ALL, + increment=False, is_async=False): if group is None and fn is None: raise ImproperlyConfigured('get_usage must be called with either ' '`group` or `fn` arguments') @@ -227,45 +248,102 @@ def get_usage(request, group=None, fn=None, key=None, rate=None, method=ALL, cache_key = _make_cache_key(group, window, rate, value, method) count = None - try: - added = cache.add(cache_key, initial_value, period + EXPIRATION_FUDGE) - except socket.gaierror: # for redis - added = False - if added: - count = initial_value + if is_async: + async def inner(): + try: + # Some caches don't have an async implementation + if hasattr(cache, 'aadd'): + added = await cache.aadd(cache_key, initial_value, period + EXPIRATION_FUDGE) + else: + added = cache.add(cache_key, initial_value, period + EXPIRATION_FUDGE) + except socket.gaierror: # for redis + added = False + if added: + count = initial_value + else: + if increment: + try: + # python3-memcached will throw a ValueError if the server is + # unavailable or (somehow) the key doesn't exist. redis, on the + # other hand, simply returns None. + if hasattr(cache, 'aincr'): + count = await cache.aincr(cache_key) + else: + count = cache.incr(cache_key) + except ValueError: + pass + else: + if hasattr(cache, 'aget'): + count = await cache.aget(cache_key, initial_value) + else: + count = cache.get(cache_key, initial_value) + + # Getting or setting the count from the cache failed + if count is None or count is False: + if getattr(settings, 'RATELIMIT_FAIL_OPEN', False): + return None + return { + 'count': 0, + 'limit': 0, + 'should_limit': True, + 'time_left': -1, + } + + time_left = window - int(time.time()) + return { + 'count': count, + 'limit': limit, + 'should_limit': count > limit, + 'time_left': time_left, + } else: - if increment: + def inner(): try: - # python3-memcached will throw a ValueError if the server is - # unavailable or (somehow) the key doesn't exist. redis, on the - # other hand, simply returns None. - count = cache.incr(cache_key) - except ValueError: - pass - else: - count = cache.get(cache_key, initial_value) - - # Getting or setting the count from the cache failed - if count is None or count is False: - if getattr(settings, 'RATELIMIT_FAIL_OPEN', False): - return None - return { - 'count': 0, - 'limit': 0, - 'should_limit': True, - 'time_left': -1, - } - - time_left = window - int(time.time()) - return { - 'count': count, - 'limit': limit, - 'should_limit': count > limit, - 'time_left': time_left, - } + added = cache.add(cache_key, initial_value, period + EXPIRATION_FUDGE) + except socket.gaierror: # for redis + added = False + if added: + count = initial_value + else: + if increment: + try: + # python3-memcached will throw a ValueError if the server is + # unavailable or (somehow) the key doesn't exist. redis, on the + # other hand, simply returns None. + count = cache.incr(cache_key) + except ValueError: + pass + else: + count = cache.get(cache_key, initial_value) + + # Getting or setting the count from the cache failed + if count is None or count is False: + if getattr(settings, 'RATELIMIT_FAIL_OPEN', False): + return None + return { + 'count': 0, + 'limit': 0, + 'should_limit': True, + 'time_left': -1, + } + + time_left = window - int(time.time()) + return { + 'count': count, + 'limit': limit, + 'should_limit': count > limit, + 'time_left': time_left, + } + + return inner + is_ratelimited.ALL = ALL is_ratelimited.UNSAFE = UNSAFE +ais_ratelimited.ALL = ALL +ais_ratelimited.UNSAFE = UNSAFE get_usage.ALL = ALL get_usage.UNSAFE = UNSAFE +aget_usage.ALL = ALL +aget_usage.UNSAFE = UNSAFE diff --git a/django_ratelimit/decorators.py b/django_ratelimit/decorators.py index 40c9541..2db2a18 100644 --- a/django_ratelimit/decorators.py +++ b/django_ratelimit/decorators.py @@ -5,7 +5,8 @@ from django_ratelimit import ALL, UNSAFE from django_ratelimit.exceptions import Ratelimited -from django_ratelimit.core import is_ratelimited +from django_ratelimit.core import is_ratelimited, ais_ratelimited +from asgiref.sync import iscoroutinefunction __all__ = ['ratelimit'] @@ -13,21 +14,35 @@ def ratelimit(group=None, key=None, rate=None, method=ALL, block=True): def decorator(fn): - @wraps(fn) - def _wrapped(request, *args, **kw): - old_limited = getattr(request, 'limited', False) - ratelimited = is_ratelimited(request=request, group=group, fn=fn, - key=key, rate=rate, method=method, - increment=True) - request.limited = ratelimited or old_limited - if ratelimited and block: - cls = getattr( - settings, 'RATELIMIT_EXCEPTION_CLASS', Ratelimited) - raise (import_string(cls) if isinstance(cls, str) else cls)() - return fn(request, *args, **kw) + if iscoroutinefunction(fn): + @wraps(fn) + async def _wrapped(request, *args, **kw): + old_limited = getattr(request, 'limited', False) + ratelimited = await ais_ratelimited(request=request, group=group, fn=fn, + key=key, rate=rate, method=method, + increment=True) + request.limited = ratelimited or old_limited + if ratelimited and block: + cls = getattr( + settings, 'RATELIMIT_EXCEPTION_CLASS', Ratelimited) + raise (import_string(cls) if isinstance(cls, str) else cls)() + return await fn(request, *args, **kw) + else: + @wraps(fn) + def _wrapped(request, *args, **kw): + old_limited = getattr(request, 'limited', False) + ratelimited = is_ratelimited(request=request, group=group, fn=fn, + key=key, rate=rate, method=method, + increment=True) + request.limited = ratelimited or old_limited + if ratelimited and block: + cls = getattr( + settings, 'RATELIMIT_EXCEPTION_CLASS', Ratelimited) + raise (import_string(cls) if isinstance(cls, str) else cls)() + return fn(request, *args, **kw) return _wrapped - return decorator + return decorator ratelimit.ALL = ALL ratelimit.UNSAFE = UNSAFE diff --git a/django_ratelimit/tests/__init__.py b/django_ratelimit/tests/__init__.py new file mode 100644 index 0000000..bbe2037 --- /dev/null +++ b/django_ratelimit/tests/__init__.py @@ -0,0 +1,14 @@ + +def my_ip(req): + return req.META['MY_THING'] + +def callable_rate(group, request): + if request.user.is_authenticated: + return None + return (0, 1) + +def mykey(group, request): + return request.META['REMOTE_ADDR'][::-1] + +class CustomRatelimitedException(Exception): + pass \ No newline at end of file diff --git a/django_ratelimit/tests/test_async.py b/django_ratelimit/tests/test_async.py new file mode 100644 index 0000000..4403d05 --- /dev/null +++ b/django_ratelimit/tests/test_async.py @@ -0,0 +1,620 @@ + +from django.core.cache import cache, InvalidCacheBackendError +from django.core.exceptions import ImproperlyConfigured +from django.test import RequestFactory, TestCase +from django.test.utils import override_settings +from django.utils.decorators import method_decorator +from django.views.generic import View + +from django_ratelimit.decorators import ratelimit +from django_ratelimit.exceptions import Ratelimited +from django_ratelimit.core import (aget_usage, ais_ratelimited, + _split_rate, _get_ip) + +from . import my_ip, CustomRatelimitedException, mykey + +def async_partial(f, **kwargs): + async def inner(*args, **kwargs_inner): + return await f(*args, **kwargs, **kwargs_inner) + + return inner + +rf = RequestFactory() + +class MockUser: + def __init__(self, authenticated=False): + self.pk = 1 + self.is_authenticated = authenticated + + +class RateParsingTests(TestCase): + def test_simple(self): + tests = ( + ('100/s', (100, 1)), + ('100/10s', (100, 10)), + ('100/10', (100, 10)), + ('100/m', (100, 60)), + ('400/10m', (400, 600)), + ('1000/h', (1000, 3600)), + ('800/d', (800, 24 * 60 * 60)), + ) + + for i, o in tests: + assert o == _split_rate(i) + + +class RatelimitAsyncTests(TestCase): + def setUp(self): + cache.clear() + + async def test_no_key(self): + @ratelimit(rate='1/m') + async def view(request): + return True + + req = rf.get('/') + with self.assertRaises(ImproperlyConfigured): + await view(req) + + async def test_ip(self): + @ratelimit(key='ip', rate='1/m', block=False) + async def view(request): + return request.limited + + assert not await view(rf.get('/')), 'First request works.' + assert await view(rf.get('/')), 'Second request is limited' + + async def test_ip_async(self): + @ratelimit(key='ip', rate='1/m', block=False) + async def view(request): + return request.limited + + assert not await view(rf.get('/')), 'First request works.' + assert await view(rf.get('/')), 'Second request is limited' + + async def test_block(self): + @ratelimit(key='ip', rate='1/m') + async def blocked(request): + return request.limited + + assert not await blocked(rf.get('/')), 'First request works.' + with self.assertRaises(Ratelimited): + await blocked(rf.get('/')), 'Second request is blocked.' + + async def test_ratelimit_custom_string_exception_class(self): + @ratelimit(key='ip', rate='1/m') + async def view(request): + return request.limited + + with self.settings( + RATELIMIT_EXCEPTION_CLASS=( + "django_ratelimit.tests.CustomRatelimitedException" + ) + ): + req = rf.get("") + assert not await view(req) + with self.assertRaises(CustomRatelimitedException): + await view(req) + + async def test_ratelimit_custom_exception_class(self): + @ratelimit(key='ip', rate='1/m') + async def view(request): + return request.limited + + with self.settings( + RATELIMIT_EXCEPTION_CLASS=CustomRatelimitedException + ): + req = rf.get("") + assert not await view(req) + with self.assertRaises(CustomRatelimitedException): + await view(req) + + async def test_method(self): + @ratelimit(key='ip', method='POST', rate='1/m', group='a', block=False) + async def limit_post(request): + return request.limited + + assert not await limit_post(rf.post('/')), 'Do not limit first POST.' + assert await limit_post(rf.post('/')), 'Limit second POST.' + assert not await limit_post(rf.get('/')), 'Do not limit GET.' + + async def test_unsafe_methods(self): + @ratelimit(key='ip', method=ratelimit.UNSAFE, rate='0/m', block=False) + async def limit_unsafe(request): + return request.limited + + assert not await limit_unsafe(rf.get('/')) + assert not await limit_unsafe(rf.head('/')) + assert not await limit_unsafe(rf.options('/')) + assert await limit_unsafe(rf.delete('/')) + assert await limit_unsafe(rf.post('/')) + assert await limit_unsafe(rf.put('/')) + assert await limit_unsafe(rf.patch('/')) + + async def test_key_get(self): + @ratelimit(key='get:foo', rate='1/m', method='GET', block=False) + async def view(request): + return request.limited + + assert not await view(rf.get('/', {'foo': 'a'})) + assert await view(rf.get('/', {'foo': 'a'})) + assert not await view(rf.get('/', {'foo': 'b'})) + assert await view(rf.get('/', {'foo': 'b'})) + + async def test_key_post(self): + @ratelimit(key='post:foo', rate='1/m', block=False) + async def view(request): + return request.limited + + assert not await view(rf.post('/', {'foo': 'a'})) + assert await view(rf.post('/', {'foo': 'a'})) + assert not await view(rf.post('/', {'foo': 'b'})) + assert await view(rf.post('/', {'foo': 'b'})) + + async def test_key_header(self): + def _req(): + req = rf.post('/') + req.META['HTTP_X_REAL_IP'] = '1.2.3.4' + return req + + @ratelimit(key='header:x-real-ip', rate='1/m', block=False) + @ratelimit(key='header:x-missing-header', rate='1/m', block=False) + async def view(request): + return request.limited + + assert not await view(_req()) + assert await view(_req()) + + async def test_rate(self): + @ratelimit(key='ip', rate='2/m', block=False) + async def twice(request): + return request.limited + + assert not await twice(rf.post('/')), 'First request is not limited.' + assert not await twice(rf.post('/')), 'Second request is not limited.' + assert await twice(rf.post('/')), 'Third request is limited.' + + async def test_zero_rate(self): + @ratelimit(key='ip', rate='0/m', block=False) + async def never(request): + return request.limited + + assert await never(rf.post('/')) + + async def test_none_rate(self): + @ratelimit(key='ip', rate=None, block=False) + async def always(request): + return request.limited + + assert not await always(rf.post('/')) + assert not await always(rf.post('/')) + assert not await always(rf.post('/')) + assert not await always(rf.post('/')) + assert not await always(rf.post('/')) + assert not await always(rf.post('/')) + assert not await always(rf.post('/')) + + async def test_callable_rate(self): + def _req(auth): + req = rf.post('/') + req.user = MockUser(authenticated=auth) + return req + + def get_rate(group, request): + if request.user.is_authenticated: + return (2, 60) + return (1, 60) + + @ratelimit(key='user_or_ip', rate=get_rate, block=False) + async def view(request): + return request.limited + + assert not await view(_req(auth=False)) + assert await view(_req(auth=False)) + assert not await view(_req(auth=True)) + assert not await view(_req(auth=True)) + assert await view(_req(auth=True)) + + async def test_callable_rate_none(self): + def _req(never_limit=False): + req = rf.post('/') + req.never_limit = never_limit + return req + + get_rate = lambda g, r: None if r.never_limit else '1/m' + + @ratelimit(key='ip', rate=get_rate, block=False) + async def view(request): + return request.limited + + assert not await view(_req()) + assert await view(_req()) + assert not await view(_req(never_limit=True)) + assert not await view(_req(never_limit=True)) + + async def test_callable_rate_zero(self): + def _req(auth): + req = rf.post('/') + req.user = MockUser(authenticated=auth) + return req + + def get_rate(group, request): + if request.user.is_authenticated: + return '1/m' + return '0/m' + + @ratelimit(key='ip', rate=get_rate, block=False) + async def view(request): + return request.limited + + assert await view(_req(auth=False)) + assert not await view(_req(auth=True)) + assert await view(_req(auth=True)) + + async def test_callable_rate_import(self): + def _req(auth): + req = rf.post('/') + req.user = MockUser(authenticated=auth) + return req + + @ratelimit(key='user_or_ip', + rate='django_ratelimit.tests.callable_rate', + block=False) + async def view(request): + return request.limited + + assert await view(_req(auth=False)) + assert not await view(_req(auth=True)) + + async def test_user_or_ip(self): + """Allow custom functions to set cache keys.""" + + def _req(auth): + req = rf.post('/') + req.user = MockUser(authenticated=auth) + return req + + @ratelimit(key='user_or_ip', rate='1/m', block=False) + async def view(request): + return request.limited + + assert not await view(_req(auth=False)) + assert await view(_req(auth=False)) + + auth = rf.post('/') + auth.user = MockUser(authenticated=True) + + assert not await view(_req(auth=True)) + assert await view(_req(auth=True)) + + async def test_callable_key_path(self): + @ratelimit(key='django_ratelimit.tests.mykey', rate='1/m', block=False) + async def view(request): + return request.limited + + assert not await view(rf.post('/')) + assert await view(rf.post('/')) + + async def test_callable_key(self): + @ratelimit(key=mykey, rate='1/m', block=False) + async def view(request): + return request.limited + + assert not await view(rf.post('/')) + assert await view(rf.post('/')) + + async def test_stacked_decorator(self): + """Allow @ratelimit to be stacked.""" + # Put the shorter one first and make sure the second one doesn't + # reset request.limited back to False. + @ratelimit(rate='1/m', block=False, key=lambda x, y: 'min') + @ratelimit(rate='10/d', block=False, key=lambda x, y: 'day') + async def view(request): + return request.limited + + assert not await view(rf.post('/')) + assert await view(rf.post('/')) + + async def test_stacked_methods(self): + """Different methods should result in different counts.""" + @ratelimit(rate='1/m', key='ip', method='GET', block=False) + @ratelimit(rate='1/m', key='ip', method='POST', block=False) + async def view(request): + return request.limited + + assert not await view(rf.get('/')) + assert not await view(rf.post('/')) + assert await view(rf.get('/')) + assert await view(rf.post('/')) + + async def test_sorted_methods(self): + """Order of the methods shouldn't matter.""" + @ratelimit(rate='1/m', key='ip', method=['GET', 'POST'], + group='a', block=False) + async def get_post(request): + return request.limited + + @ratelimit(rate='1/m', key='ip', method=['POST', 'GET'], + group='a', block=False) + async def post_get(request): + return request.limited + + assert not await get_post(rf.get('/')) + assert await post_get(rf.get('/')) + + async def test_ratelimit_full_mask_v4(self): + @ratelimit(rate='1/m', key='ip', block=False) + async def view(request): + return request.limited + + with self.settings(RATELIMIT_IPV4_MASK=32): + req = rf.get('/') + req.META['REMOTE_ADDR'] = '10.1.1.1' + assert not await view(req) + assert await view(req) + + req = rf.get('/') + req.META['REMOTE_ADDR'] = '10.1.1.2' + assert not await view(req) + + async def test_ratelimit_full_mask_v6(self): + @ratelimit(rate='1/m', key='ip', block=False) + async def view(request): + return request.limited + + with self.settings(RATELIMIT_IPV6_MASK=128): + req = rf.get('/') + req.META['REMOTE_ADDR'] = '2001:db8::1000' + assert not await view(req) + assert await view(req) + + req = rf.get('/') + req.META['REMOTE_ADDR'] = '2001:db8::1001' + assert not await view(req) + + async def test_ratelimit_mask_v4(self): + @ratelimit(rate='1/m', key='ip', block=False) + async def view(request): + return request.limited + + with self.settings(RATELIMIT_IPV4_MASK=16): + req = rf.get('/') + req.META['REMOTE_ADDR'] = '10.1.1.1' + assert not await view(req) + assert await view(req) + + req = rf.get('/') + req.META['REMOTE_ADDR'] = '10.1.0.1' + assert await view(req) + + req = rf.get('/') + req.META['REMOTE_ADDR'] = '192.168.1.1' + assert not await view(req) + + async def test_ratelimit_mask_v6(self): + @ratelimit(rate='1/m', key='ip', block=False) + async def view(request): + return request.limited + + with self.settings(RATELIMIT_IPV6_MASK=64): + req = rf.get('/') + req.META['REMOTE_ADDR'] = '2001:db8::1000' + assert not await view(req) + assert await view(req) + + req = rf.get('/') + req.META['REMOTE_ADDR'] = '2001:db8::1001' + assert await view(req) + + req = rf.get('/') + req.META['REMOTE_ADDR'] = '2001:db9::1000' + assert not await view(req) + + +class AsyncFunctionsTests(TestCase): + def setUp(self): + cache.clear() + + async def test_ais_ratelimited(self): + not_increment = async_partial(ais_ratelimited, increment=False, rate='1/m', + method=ais_ratelimited.ALL, key='ip', group='a') + + # Does not increment. Count still 0. Does not rate limit + # because 0 < 1. + assert not await not_increment(rf.get('/')) + + # Does not increment. Count still 1. Not limited because 1 > 1 + # is false. + assert not await not_increment(rf.get('/')) + + async def test_is_ratelimited_increment(self): + do_increment = async_partial(ais_ratelimited, increment=True, rate='1/m', + method=ais_ratelimited.ALL, key='ip', group='a') + + # Increments. Does not rate limit because 0 < 1. Count now 1. + assert not await do_increment(rf.get('/')) + + # Count = 2, 2 > 1. + assert await do_increment(rf.get('/')) + + async def test_aget_usage(self): + _get_usage = async_partial(aget_usage, method=aget_usage.ALL, key='ip', + rate='1/m', group='a') + usage = await _get_usage(rf.get('/')) + + self.assertEqual(usage['count'], 0) + self.assertEqual(usage['limit'], 1) + self.assertLessEqual(usage['time_left'], 60) + self.assertFalse(usage['should_limit']) + + async def test_aget_usage_increment(self): + _get_usage = async_partial(aget_usage, method=aget_usage.ALL, key='ip', + rate='1/m', group='a', increment=True) + await _get_usage(rf.get('/')) + usage = await _get_usage(rf.get('/')) + + self.assertEqual(usage['count'], 2) + self.assertEqual(usage['limit'], 1) + self.assertLessEqual(usage['time_left'], 60) + self.assertTrue(usage['should_limit']) + + async def test_not_increment_after_increment(self): + _get_usage = async_partial(aget_usage, method=aget_usage.ALL, key='ip', + rate='1/m', group='a') + await _get_usage(rf.get('/'), increment=True) + await _get_usage(rf.get('/'), increment=True) + usage = await _get_usage(rf.get('/')) + + self.assertEqual(usage['count'], 2) + self.assertEqual(usage['limit'], 1) + self.assertLessEqual(usage['time_left'], 60) + self.assertTrue(usage['should_limit']) + + async def test_get_usage_called_without_group_or_fn(self): + with self.assertRaises(ImproperlyConfigured): + await aget_usage(rf.get('/'), key='ip') + + +class RatelimitACBVTests(TestCase): + def setUp(self): + cache.clear() + + async def test_method_decorator(self): + class TestView(View): + @method_decorator(ratelimit(key='ip', rate='1/m', block=False)) + async def post(self, request): + return request.limited + + view = TestView.as_view() + + assert not await view(rf.post('/')) + assert await view(rf.post('/')) + + async def test_class_decorator(self): + @method_decorator(ratelimit(key='ip', rate='1/m', block=False), + name='get') + class TestView(View): + async def get(self, request): + return request.limited + + view = TestView.as_view() + + assert not await view(rf.get('/')) + assert await view(rf.get('/')) + + async def test_wrap_view(self): + class TestView(View): + async def get(self, request): + return request.limited + + view = TestView.as_view() + wrapped = ratelimit(key='ip', rate='1/m', block=False)(view) + + assert not await wrapped(rf.get('/')) + assert await wrapped(rf.get('/')) + + async def test_methods_counted_separately(self): + class TestView(View): + @method_decorator(ratelimit(key='ip', rate='1/m', + method='GET', block=False)) + async def get(self, request): + return request.limited + + @method_decorator(ratelimit(key='ip', rate='1/m', + method='POST', block=False)) + async def post(self, request): + return request.limited + + view = TestView.as_view() + + assert not await view(rf.get('/')) + assert await view(rf.get('/')) + assert not await view(rf.post('/')) + + async def test_views_counted_separately(self): + class TestView(View): + @method_decorator(ratelimit(key='ip', rate='1/m', + method='GET', block=False)) + async def get(self, request): + return request.limited + + class AnotherTestView(View): + @method_decorator(ratelimit(key='ip', rate='1/m', + method='GET', block=False)) + async def get(self, request): + return request.limited + + test_view = TestView.as_view() + another_view = AnotherTestView.as_view() + + assert not await test_view(rf.get('/')) + assert await test_view(rf.get('/')) + assert not await another_view(rf.get('/')) + + +class AsyncCacheFailTests(TestCase): + @override_settings(RATELIMIT_USE_CACHE='fake-cache') + async def test_bad_cache(self): + @ratelimit(key='ip', rate='1/m', block=False) + async def view(request): + return request.limited + + with self.assertRaises(InvalidCacheBackendError): + await view(rf.post('/')) + + @override_settings(RATELIMIT_USE_CACHE='connection-errors') + async def test_limit_on_cache_connection_error(self): + @ratelimit(key='ip', rate='10/m', block=False) + async def view(request): + return request.limited + + assert await view(rf.post('/')) + + @override_settings(RATELIMIT_USE_CACHE='connection-errors', + RATELIMIT_FAIL_OPEN=True) + async def test_fail_open_setting(self): + @ratelimit(key='ip', rate='1/m', block=False) + async def view(request): + return request.limited + + assert not await view(rf.get('/')) + assert not await view(rf.get('/')) + + @override_settings(RATELIMIT_USE_CACHE='connection-errors') + async def test_is_ratelimited_cache_connection_error_without_increment(self): + async def not_increment(request): + return await ais_ratelimited(request, increment=False, + method=ais_ratelimited.ALL, key='ip', + rate='1/m', group='a') + + assert not await not_increment(rf.get('/')) + assert not await not_increment(rf.get('/')) + + @override_settings(RATELIMIT_USE_CACHE='connection-errors') + async def test_is_ratelimited_cache_connection_error_with_increment(self): + async def do_increment(request): + return await ais_ratelimited(request, increment=True, + method=ais_ratelimited.ALL, key='ip', + rate='1/m', group='a') + + assert await do_increment(rf.get('/')) + assert await do_increment(rf.get('/')) + + @override_settings(RATELIMIT_USE_CACHE='connection-errors-redis') + async def test_is_ratelimited_cache_connection_error_with_increment_redis(self): + async def do_increment(request): + return ais_ratelimited(request, increment=True, + method=ais_ratelimited.ALL, key='ip', + rate='1/m', group='a') + + assert await do_increment(rf.get('/')) + assert await do_increment(rf.get('/')) + + @override_settings(RATELIMIT_USE_CACHE='instant-expiration') + async def test_cache_timeout(self): + @ratelimit(key='ip', rate='1/m') + async def view(request): + return True + + assert await view(rf.get('/')) + assert await view(rf.get('/')) diff --git a/django_ratelimit/tests.py b/django_ratelimit/tests/test_sync.py similarity index 98% rename from django_ratelimit/tests.py rename to django_ratelimit/tests/test_sync.py index a58c89e..2ff49d7 100644 --- a/django_ratelimit/tests.py +++ b/django_ratelimit/tests/test_sync.py @@ -12,6 +12,7 @@ from django_ratelimit.core import (get_usage, is_ratelimited, _split_rate, _get_ip) +from . import my_ip, mykey, CustomRatelimitedException rf = RequestFactory() @@ -38,21 +39,7 @@ def test_simple(self): assert o == _split_rate(i) -def callable_rate(group, request): - if request.user.is_authenticated: - return None - return (0, 1) - - -def mykey(group, request): - return request.META['REMOTE_ADDR'][::-1] - - -class CustomRatelimitedException(Exception): - pass - - -class RatelimitTests(TestCase): +class RatelimitSyncTests(TestCase): def setUp(self): cache.clear() @@ -72,6 +59,14 @@ def view(request): assert not view(rf.get('/')), 'First request works.' assert view(rf.get('/')), 'Second request is limited' + + async def test_ip_async(self): + @ratelimit(key='ip', rate='1/m', block=False) + async def view(request): + return request.limited + + assert not await view(rf.get('/')), 'First request works.' + assert await view(rf.get('/')), 'Second request is limited' def test_block(self): @ratelimit(key='ip', rate='1/m') @@ -621,10 +616,6 @@ def view(request): assert view(rf.get('/')) -def my_ip(req): - return req.META['MY_THING'] - - class IpMetaTests(TestCase): def test_default(self): req = rf.get('/') diff --git a/tox.ini b/tox.ini index 636583e..77ff112 100644 --- a/tox.ini +++ b/tox.ini @@ -17,6 +17,7 @@ deps = django42: Django>=4.2,<4.3 django50: Django>=5.0a1,<5.1 djangomain: https://github.com/django/django/archive/main.tar.gz + asgiref>3.7.<4.0 pymemcache>=4.0,<5.0 django-redis>=5.2,<6.0 flake8