Added support and test for async #300

Open · wants to merge 5 commits into base: main
148 changes: 113 additions & 35 deletions django_ratelimit/core.py
@@ -14,7 +14,7 @@
from django_ratelimit import ALL, UNSAFE


__all__ = ['is_ratelimited', 'get_usage']
__all__ = ['is_ratelimited', 'ais_ratelimited', 'get_usage', 'aget_usage']

_PERIODS = {
's': 1,
@@ -156,9 +156,30 @@ def is_ratelimited(request, group=None, fn=None, key=None, rate=None,

return usage['should_limit']

async def ais_ratelimited(request, group=None, fn=None, key=None, rate=None,
method=ALL, increment=False):
usage = await aget_usage(request, group, fn, key, rate, method, increment)
if usage is None:
return False

return usage['should_limit']


def get_usage(request, group=None, fn=None, key=None, rate=None, method=ALL,
increment=False):
usage = _get_usage(request, group, fn, key, rate, method, increment)
if usage is not None:
return usage()

async def aget_usage(request, group=None, fn=None, key=None, rate=None, method=ALL,
increment=False):
usage = _get_usage(request, group, fn, key, rate, method, increment, is_async=True)
if usage is not None:
return await usage()


def _get_usage(request, group=None, fn=None, key=None, rate=None, method=ALL,
increment=False, is_async=False):
if group is None and fn is None:
raise ImproperlyConfigured('get_usage must be called with either '
'`group` or `fn` arguments')
@@ -227,45 +248,102 @@ def get_usage(request, group=None, fn=None, key=None, rate=None, method=ALL,
cache_key = _make_cache_key(group, window, rate, value, method)

count = None
try:
added = cache.add(cache_key, initial_value, period + EXPIRATION_FUDGE)
except socket.gaierror: # for redis
added = False
if added:
count = initial_value
if is_async:
async def inner():
try:
# Some caches don't have an async implementation
[Review thread]
Owner: It looks like the aadd methods are defined in BaseCache (see the next comment down for the link), so in theory any cache should have access to them.
Author: I've tried it before, and if someone is using an old cache this can fail. This is just to support caches that may not define the cache interface with those methods. Test implementations can also be missing these methods.
Owner: Can you clarify what you mean by "old cache"? Is it an implementation that doesn't inherit from BaseCache?

if hasattr(cache, 'aadd'):
added = await cache.aadd(cache_key, initial_value, period + EXPIRATION_FUDGE)
else:
added = cache.add(cache_key, initial_value, period + EXPIRATION_FUDGE)
except socket.gaierror: # for redis
added = False
if added:
count = initial_value
else:
if increment:
try:
# python3-memcached will throw a ValueError if the server is
# unavailable or (somehow) the key doesn't exist. redis, on the
# other hand, simply returns None.
if hasattr(cache, 'aincr'):
count = await cache.aincr(cache_key)
else:
count = cache.incr(cache_key)
except ValueError:
pass
else:
if hasattr(cache, 'aget'):
count = await cache.aget(cache_key, initial_value)
else:
count = cache.get(cache_key, initial_value)

# Getting or setting the count from the cache failed
if count is None or count is False:
if getattr(settings, 'RATELIMIT_FAIL_OPEN', False):
return None
return {
'count': 0,
'limit': 0,
'should_limit': True,
'time_left': -1,
}

time_left = window - int(time.time())
return {
'count': count,
'limit': limit,
'should_limit': count > limit,
'time_left': time_left,
}
else:
if increment:
def inner():
try:
# python3-memcached will throw a ValueError if the server is
# unavailable or (somehow) the key doesn't exist. redis, on the
# other hand, simply returns None.
count = cache.incr(cache_key)
except ValueError:
pass
else:
count = cache.get(cache_key, initial_value)

# Getting or setting the count from the cache failed
if count is None or count is False:
if getattr(settings, 'RATELIMIT_FAIL_OPEN', False):
return None
return {
'count': 0,
'limit': 0,
'should_limit': True,
'time_left': -1,
}

time_left = window - int(time.time())
return {
'count': count,
'limit': limit,
'should_limit': count > limit,
'time_left': time_left,
}
added = cache.add(cache_key, initial_value, period + EXPIRATION_FUDGE)
except socket.gaierror: # for redis
added = False
if added:
count = initial_value
else:
if increment:
try:
# python3-memcached will throw a ValueError if the server is
# unavailable or (somehow) the key doesn't exist. redis, on the
# other hand, simply returns None.
count = cache.incr(cache_key)
except ValueError:
pass
else:
count = cache.get(cache_key, initial_value)

# Getting or setting the count from the cache failed
if count is None or count is False:
if getattr(settings, 'RATELIMIT_FAIL_OPEN', False):
return None
return {
'count': 0,
'limit': 0,
'should_limit': True,
'time_left': -1,
}

time_left = window - int(time.time())
return {
'count': count,
'limit': limit,
'should_limit': count > limit,
'time_left': time_left,
}

return inner



is_ratelimited.ALL = ALL
is_ratelimited.UNSAFE = UNSAFE
ais_ratelimited.ALL = ALL
ais_ratelimited.UNSAFE = UNSAFE
get_usage.ALL = ALL
get_usage.UNSAFE = UNSAFE
aget_usage.ALL = ALL
aget_usage.UNSAFE = UNSAFE
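
For reference, a minimal usage sketch of the new async helpers, assuming the ais_ratelimited and aget_usage signatures shown in the diff above; the view, group, key, and rate values below are illustrative only and not part of this PR:

from django.http import HttpResponse, JsonResponse
from django_ratelimit.core import ais_ratelimited, aget_usage

async def example_async_view(request):
    # ais_ratelimited mirrors is_ratelimited but awaits the async cache calls.
    limited = await ais_ratelimited(request=request, group='example-group',
                                    key='ip', rate='5/m', increment=True)
    if limited:
        return HttpResponse('Too many requests', status=429)

    # aget_usage returns the same dict shape as get_usage
    # (count, limit, should_limit, time_left), or None when failing open.
    usage = await aget_usage(request=request, group='example-group',
                             key='ip', rate='5/m')
    return JsonResponse({'usage': usage})
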
43 changes: 29 additions & 14 deletions django_ratelimit/decorators.py
@@ -5,29 +5,44 @@

from django_ratelimit import ALL, UNSAFE
from django_ratelimit.exceptions import Ratelimited
from django_ratelimit.core import is_ratelimited
from django_ratelimit.core import is_ratelimited, ais_ratelimited
from asgiref.sync import iscoroutinefunction


__all__ = ['ratelimit']


def ratelimit(group=None, key=None, rate=None, method=ALL, block=True):
def decorator(fn):
@wraps(fn)
def _wrapped(request, *args, **kw):
old_limited = getattr(request, 'limited', False)
ratelimited = is_ratelimited(request=request, group=group, fn=fn,
key=key, rate=rate, method=method,
increment=True)
request.limited = ratelimited or old_limited
if ratelimited and block:
cls = getattr(
settings, 'RATELIMIT_EXCEPTION_CLASS', Ratelimited)
raise (import_string(cls) if isinstance(cls, str) else cls)()
return fn(request, *args, **kw)
if iscoroutinefunction(fn):
[Review thread on lines 15 to +17]
Owner: Not something to change here, but I wonder if having a second aratelimit decorator might be simpler / clearer. I know we've only got 3 months of 3.2 official support left, but I think it might make it easier to say "aratelimit requires Django >= 4". "Explicit is better than implicit," after all.
@mlissner (Jan 19, 2024): Async is barely complete even in Django 5.0, so I doubt anybody would care if 3.2 doesn't support the async decorator. If you're doing async work, you're probably running the latest version of Django.

@wraps(fn)
async def _wrapped(request, *args, **kw):
old_limited = getattr(request, 'limited', False)
ratelimited = await ais_ratelimited(request=request, group=group, fn=fn,
key=key, rate=rate, method=method,
increment=True)
request.limited = ratelimited or old_limited
if ratelimited and block:
cls = getattr(
settings, 'RATELIMIT_EXCEPTION_CLASS', Ratelimited)
raise (import_string(cls) if isinstance(cls, str) else cls)()
return await fn(request, *args, **kw)
else:
@wraps(fn)
def _wrapped(request, *args, **kw):
old_limited = getattr(request, 'limited', False)
ratelimited = is_ratelimited(request=request, group=group, fn=fn,
key=key, rate=rate, method=method,
increment=True)
request.limited = ratelimited or old_limited
if ratelimited and block:
cls = getattr(
settings, 'RATELIMIT_EXCEPTION_CLASS', Ratelimited)
raise (import_string(cls) if isinstance(cls, str) else cls)()
return fn(request, *args, **kw)
return _wrapped
return decorator

return decorator

ratelimit.ALL = ALL
ratelimit.UNSAFE = UNSAFE
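
Because the updated decorator dispatches on iscoroutinefunction(fn), the existing @ratelimit decorator should apply directly to async views. A small illustrative sketch; the view, key, and rate are hypothetical and not taken from this PR:

from django.http import JsonResponse
from django_ratelimit.decorators import ratelimit

@ratelimit(key='ip', rate='10/m', block=True)
async def async_api_view(request):
    # When the limit is exceeded and block=True, the decorator raises
    # Ratelimited (or RATELIMIT_EXCEPTION_CLASS) before this body runs.
    return JsonResponse({'ok': True})
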
14 changes: 14 additions & 0 deletions django_ratelimit/tests/__init__.py
@@ -0,0 +1,14 @@

def my_ip(req):
return req.META['MY_THING']

def callable_rate(group, request):
if request.user.is_authenticated:
return None
return (0, 1)

def mykey(group, request):
return request.META['REMOTE_ADDR'][::-1]

class CustomRatelimitedException(Exception):
pass
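
The PR's actual async test module is not shown in this diff (only these shared helpers are), but a test of the async path might look roughly like the sketch below. The class and method names are assumptions, and it relies on Django's support for async test methods (Django 3.1+):

from django.http import HttpResponse
from django.test import RequestFactory, TestCase
from django_ratelimit.decorators import ratelimit

class AsyncRatelimitSketchTests(TestCase):
    async def test_async_view_sets_limited_flag(self):
        # Hypothetical test: decorate an async view with a 1/m rate and hit it
        # twice; the second call should mark the request as limited.
        @ratelimit(key='ip', rate='1/m', block=False)
        async def view(request):
            return HttpResponse('ok')

        request = RequestFactory().get('/')
        await view(request)   # first hit: within the limit
        await view(request)   # second hit: exceeds 1/m
        self.assertTrue(request.limited)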