From 3c86a1eb58457fa3cc5e6a3c9374f3b6b27ab655 Mon Sep 17 00:00:00 2001 From: Danang Massandy Date: Wed, 30 Oct 2024 11:52:23 +0000 Subject: [PATCH 01/13] add ratelimiter model --- .../gap_api/migrations/0004_apiratelimiter.py | 26 ++++ django_project/gap_api/models/__init__.py | 1 + django_project/gap_api/models/rate_limiter.py | 128 ++++++++++++++++++ 3 files changed, 155 insertions(+) create mode 100644 django_project/gap_api/migrations/0004_apiratelimiter.py create mode 100644 django_project/gap_api/models/rate_limiter.py diff --git a/django_project/gap_api/migrations/0004_apiratelimiter.py b/django_project/gap_api/migrations/0004_apiratelimiter.py new file mode 100644 index 0000000..92f92cc --- /dev/null +++ b/django_project/gap_api/migrations/0004_apiratelimiter.py @@ -0,0 +1,26 @@ +# Generated by Django 4.2.7 on 2024-10-30 11:51 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('gap_api', '0003_location_location_user_locationname'), + ] + + operations = [ + migrations.CreateModel( + name='APIRateLimiter', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('minute_limit', models.IntegerField()), + ('hour_limit', models.IntegerField()), + ('day_limit', models.IntegerField()), + ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, to=settings.AUTH_USER_MODEL)), + ], + ), + ] diff --git a/django_project/gap_api/models/__init__.py b/django_project/gap_api/models/__init__.py index 7317e7b..2ced82d 100644 --- a/django_project/gap_api/models/__init__.py +++ b/django_project/gap_api/models/__init__.py @@ -1,3 +1,4 @@ from gap_api.models.api_request_log import * # noqa from gap_api.models.api_config import * # noqa from gap_api.models.location import * # noqa +from gap_api.models.rate_limiter import * # noqa diff --git a/django_project/gap_api/models/rate_limiter.py b/django_project/gap_api/models/rate_limiter.py new file mode 100644 index 0000000..d9f52e4 --- /dev/null +++ b/django_project/gap_api/models/rate_limiter.py @@ -0,0 +1,128 @@ +# coding=utf-8 +""" +Tomorrow Now GAP API. + +.. note:: Models for Rate Limiter +""" + +from django.db import models +from django.conf import settings +from django.core.cache import cache +from django.db.models.signals import post_save +from django.dispatch import receiver + + +class APIRateLimiter(models.Model): + """Models that stores GAP API rate limiter.""" + + GLOBAL_CACHE_KEY = 'gap-api-ratelimit-global' + + user = models.ForeignKey( + settings.AUTH_USER_MODEL, + on_delete=models.SET_NULL, + null=True, + blank=True + ) + minute_limit = models.IntegerField() + hour_limit = models.IntegerField() + day_limit = models.IntegerField() + + @property + def config_name(self): + """Return config name.""" + if self.user: + return self.user.username + return 'global' + + @property + def cache_key(self): + """Return cache key for this config.""" + if self.user: + return f'gap-api-ratelimit-{self.user.id}' + return APIRateLimiter.GLOBAL_CACHE_KEY + + @property + def cache_value(self): + """Get dict cache value.""" + return { + 'minute': self.minute_limit, + 'hour': self.hour_limit, + 'day': self.day_limit + } + + def set_cache(self): + """Set rate limit to cache.""" + cache.set( + self.cache_key, + f'{self.minute_limit}:{self.hour_limit}:{self.day_limit}', + timeout=None + ) + + def clear_cache(self): + """Clear cache for this config.""" + cache.delete(self.cache_key) + + @staticmethod + def parse_cache_value(cache_str: str): + """Parse cache value.""" + values = cache_str.split(':') + return { + 'minute': values[0], + 'hour': values[1], + 'day': values[2] + } + + @staticmethod + def get_global_config(): + """Get global config cache.""" + config_cache = cache.get(APIRateLimiter.GLOBAL_CACHE_KEY, None) + if config_cache: + return APIRateLimiter.parse_cache_value(config_cache) + + limit = APIRateLimiter.objects.filter( + user=None + ).first() + if limit: + # set to cache + limit.set_cache() + return limit.cache_value + + return None + + @staticmethod + def get_config(user): + """Return config for given user.""" + cache_key = f'gap-api-ratelimit-{user.id}' + config_cache = cache.get(cache_key, None) + + if config_cache == 'global': + # use global config + pass + elif config_cache is None: + # find from table + limit = APIRateLimiter.objects.filter( + user=user + ).first() + + if limit: + # set to cache if found + limit.set_cache() + return limit.cache_value + else: + # set to use global config + cache.set(cache_key, 'global') + else: + # parse config for the user + return APIRateLimiter.parse_cache_value(config_cache) + + return APIRateLimiter.get_global_config() + + +@receiver(post_save, sender=APIRateLimiter) +def ratelimiter_post_create( + sender, instance: APIRateLimiter, created, *args, **kwargs): + """Clear cache after saving the object.""" + if created: + return + + cache.delete(instance.cache_key) From 2d33d09578b604472b06e30d6b219bf0cbdc9f74 Mon Sep 17 00:00:00 2001 From: Danang Massandy Date: Wed, 30 Oct 2024 12:22:44 +0000 Subject: [PATCH 02/13] add admin and fixtures --- django_project/gap_api/admin.py | 34 ++++++++++++++++++- .../gap_api/fixtures/2.apiratelimiter.json | 12 +++++++ 2 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 django_project/gap_api/fixtures/2.apiratelimiter.json diff --git a/django_project/gap_api/admin.py b/django_project/gap_api/admin.py index eac3535..ff752f2 100644 --- a/django_project/gap_api/admin.py +++ b/django_project/gap_api/admin.py @@ -6,15 +6,23 @@ """ import random +import json from django.contrib import admin from django.db.models import Count, TextField from django.db.models.fields.json import KeyTextTransform from django.db.models.functions import TruncDay, Cast +from django.http import HttpResponse +from django.core.serializers.json import DjangoJSONEncoder from rest_framework_tracking.admin import APIRequestLogAdmin from rest_framework_tracking.models import APIRequestLog as BaseAPIRequestLog from gap.models import DatasetType -from gap_api.models import APIRequestLog, DatasetTypeAPIConfig, Location +from gap_api.models import ( + APIRequestLog, + DatasetTypeAPIConfig, + Location, + APIRateLimiter +) admin.site.unregister(BaseAPIRequestLog) @@ -182,6 +190,30 @@ class LocationAdmin(admin.ModelAdmin): list_filter = ('user',) +@admin.action(description='Export rate limiter as json') +def export_rate_limiter_as_json(modeladmin, request, queryset): + """Download rate limiter.""" + fields_to_include = [ + 'pk', 'user_id', 'minute_limit', 'hour_limit', 'day_limit'] + data = list(queryset.all().values(*fields_to_include)) + + # Convert the data to JSON + response_data = json.dumps(data, cls=DjangoJSONEncoder) + + # Create the HttpResponse with the correct content_type for JSON + response = HttpResponse(response_data, content_type='application/json') + response['Content-Disposition'] = 'attachment; filename=rate_limiter.json' + return response + + +class APIRateLimiterAdmin(admin.ModelAdmin): + """Admin class for APIRateLimiter.""" + + list_display = ('config_name', 'minute_limit', 'hour_limit', 'day_limit',) + actions = (export_rate_limiter_as_json,) + + admin.site.register(APIRequestLog, GapAPIRequestLogAdmin) admin.site.register(DatasetTypeAPIConfig, GapAPIDatasetTypeConfigAdmin) admin.site.register(Location, LocationAdmin) +admin.site.register(APIRateLimiter, APIRateLimiterAdmin) diff --git a/django_project/gap_api/fixtures/2.apiratelimiter.json b/django_project/gap_api/fixtures/2.apiratelimiter.json new file mode 100644 index 0000000..e737c1e --- /dev/null +++ b/django_project/gap_api/fixtures/2.apiratelimiter.json @@ -0,0 +1,12 @@ +[ +{ + "model": "gap_api.apiratelimiter", + "pk": 1, + "fields": { + "user": null, + "minute_limit": 100, + "hour_limit": 1000, + "day_limit": 10000 + } +} +] From 4fd72920c8cabd21c3987477fdb95a8af31f476a Mon Sep 17 00:00:00 2001 From: Danang Massandy Date: Wed, 30 Oct 2024 13:12:42 +0000 Subject: [PATCH 03/13] add throttle class --- .../gap_api/api_views/measurement.py | 3 +- django_project/gap_api/mixins/__init__.py | 1 + django_project/gap_api/mixins/rate_limiter.py | 144 ++++++++++++++++++ 3 files changed, 147 insertions(+), 1 deletion(-) create mode 100644 django_project/gap_api/mixins/rate_limiter.py diff --git a/django_project/gap_api/api_views/measurement.py b/django_project/gap_api/api_views/measurement.py index eb0e818..cec8196 100644 --- a/django_project/gap_api/api_views/measurement.py +++ b/django_project/gap_api/api_views/measurement.py @@ -39,7 +39,7 @@ from gap_api.models import DatasetTypeAPIConfig, Location from gap_api.serializers.common import APIErrorSerializer from gap_api.utils.helper import ApiTag -from gap_api.mixins import GAPAPILoggingMixin +from gap_api.mixins import GAPAPILoggingMixin, CounterSlidingWindowThrottle def attribute_list(): @@ -78,6 +78,7 @@ class MeasurementAPI(GAPAPILoggingMixin, APIView): date_format = '%Y-%m-%d' time_format = '%H:%M:%S' permission_classes = [IsAuthenticated] + throttle_classes = [CounterSlidingWindowThrottle] api_parameters = [ openapi.Parameter( 'attributes', diff --git a/django_project/gap_api/mixins/__init__.py b/django_project/gap_api/mixins/__init__.py index 5fb58ba..4be42fc 100644 --- a/django_project/gap_api/mixins/__init__.py +++ b/django_project/gap_api/mixins/__init__.py @@ -1 +1,2 @@ from gap_api.mixins.logging import * # noqa +from gap_api.mixins.rate_limiter import * # noqa diff --git a/django_project/gap_api/mixins/rate_limiter.py b/django_project/gap_api/mixins/rate_limiter.py new file mode 100644 index 0000000..56fb572 --- /dev/null +++ b/django_project/gap_api/mixins/rate_limiter.py @@ -0,0 +1,144 @@ +# coding=utf-8 +""" +Tomorrow Now GAP API. + +.. note:: Mixin for API Tracking +""" + +import time +from redis import Redis +from django.core.cache import cache +from rest_framework.throttling import BaseThrottle + + +class RateLimiter: + """RateLimiter using sliding window counter.""" + + def __init__(self, user_id, rate_limits): + """Initialize rate limiter. + + :param user_id: User identifier. + :param rate_limits: Dictionary with limit_duration_in_minutes: + max_requests. E.g., {1: 100, 60: 1000, 1440: 10000} + (100 requests per minute, 1000 per hour, 10,000 per day). + """ + self.user_id = user_id + self.rate_limits = rate_limits + self.redis: Redis = cache._cache.get_client() + + def _get_current_minute(self): + """Get the current timestamp rounded to the nearest min.""" + return int(time.time() // 60) + + def _get_current_hour(self): + """Get the current timestamp rounded to the nearest hour.""" + return int(time.time() // 3600) + + def _get_redis_key(self, granularity): + """Get the Redis key for this user and time granularity.""" + if granularity == 'minute': + return f"rate_limit:minute:{self.user_id}" + elif granularity == 'hour': + return f"rate_limit:hour:{self.user_id}" + + def _increment_request_count(self): + """Increment the request count for the current minute and hour.""" + current_minute = self._get_current_minute() + current_hour = self._get_current_hour() + + # Increment minute-level requests + minute_key = self._get_redis_key('minute') + self.redis.hincrby(minute_key, current_minute, 1) + # 2 hours expiration for minute data + self.redis.expire(minute_key, 2 * 60 * 60) + + # Increment hour-level requests (for daily limit) + hour_key = self._get_redis_key('hour') + self.redis.hincrby(hour_key, current_hour, 1) + # 25 hours expiration for hour data + self.redis.expire(hour_key, 25 * 60 * 60) + + # Clean up old minute and hour entries + self._cleanup_old_entries() + + def _cleanup_old_entries(self): + """Remove counters older than the longest rate limit window.""" + current_minute = self._get_current_minute() + current_hour = self._get_current_hour() + + # Cleanup minute-level data (older than + # the longest minute-based window) + longest_minute_window = max( + [duration for duration in self.rate_limits.keys() if + duration < 60] + ) + minute_cutoff = current_minute - longest_minute_window + minute_key = self._get_redis_key('minute') + minute_buckets = self.redis.hkeys(minute_key) + for minute in minute_buckets: + if int(minute) < minute_cutoff: + self.redis.hdel(minute_key, minute) + + # Cleanup hour-level data (older than the longest hour-based window, + # i.e., 24 hours) + hour_cutoff = current_hour - 24 + hour_key = self._get_redis_key('hour') + hour_buckets = self.redis.hkeys(hour_key) + for hour in hour_buckets: + if int(hour) < hour_cutoff: + self.redis.hdel(hour_key, hour) + + def _get_request_count(self, duration_in_minutes): + """Get the total request count for the last `duration_in_minutes`.""" + if duration_in_minutes < 60: + # Minute-based rate limit + # (for short durations like 1 minute or 1 hour) + current_minute = self._get_current_minute() + redis_key = self._get_redis_key('minute') + total_count = 0 + for i in range(duration_in_minutes): + minute = current_minute - i + count = self.redis.hget(redis_key, minute) + if count: + total_count += int(count) + return total_count + + else: + # Hour-based rate limit (for longer durations like a day) + current_hour = self._get_current_hour() + redis_key = self._get_redis_key('hour') + total_count = 0 + for i in range(duration_in_minutes // 60): + hour = current_hour - i + count = self.redis.hget(redis_key, hour) + if count: + total_count += int(count) + return total_count + + def is_rate_limited(self): + """Check if the user is rate-limited based on defined rate limits.""" + for duration_in_minutes, max_requests in self.rate_limits.items(): + request_count = self._get_request_count(duration_in_minutes) + if request_count >= max_requests: + return True + return False + + +class CounterSlidingWindowThrottle(BaseThrottle): + """Custom throttle class using sliding window counter.""" + + def allow_request(self, request, view): + """Check whether request is allowed.""" + rate_limits = { + 1: 100, # 100 requests per minute + 60: 1000, # 1000 requests per hour + 1440: 10000 # 10,000 requests per day + } + + rate_limiter = RateLimiter(request.user.id, rate_limits) + + if rate_limiter.is_rate_limited(): + return False + + rate_limiter._increment_request_count() + return True From fe9f1b2912b2663c2039635ba07a62b153a671ac Mon Sep 17 00:00:00 2001 From: Danang Massandy Date: Wed, 30 Oct 2024 13:13:09 +0000 Subject: [PATCH 04/13] fix lint --- django_project/gap_api/mixins/rate_limiter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/django_project/gap_api/mixins/rate_limiter.py b/django_project/gap_api/mixins/rate_limiter.py index 56fb572..1d43a53 100644 --- a/django_project/gap_api/mixins/rate_limiter.py +++ b/django_project/gap_api/mixins/rate_limiter.py @@ -139,6 +139,6 @@ def allow_request(self, request, view): if rate_limiter.is_rate_limited(): return False - + rate_limiter._increment_request_count() return True From 4eb8faa11cb00008f50f7df08f98fa00037abcaa Mon Sep 17 00:00:00 2001 From: Danang Massandy Date: Wed, 30 Oct 2024 18:21:22 +0000 Subject: [PATCH 05/13] add tests --- deployment/docker/requirements-dev.txt | 5 +- django_project/gap_api/mixins/rate_limiter.py | 18 ++- .../gap_api/tests/test_api_request_log.py | 7 + .../gap_api/tests/test_measurement_api.py | 19 ++- .../gap_api/tests/test_rate_limiter.py | 141 ++++++++++++++++++ 5 files changed, 182 insertions(+), 8 deletions(-) create mode 100644 django_project/gap_api/tests/test_rate_limiter.py diff --git a/deployment/docker/requirements-dev.txt b/deployment/docker/requirements-dev.txt index a2d4faa..f92ca9e 100644 --- a/deployment/docker/requirements-dev.txt +++ b/deployment/docker/requirements-dev.txt @@ -21,4 +21,7 @@ pytest-django requests-mock # memory profiler -memory-profiler \ No newline at end of file +memory-profiler + +# fakeredis +fakeredis==2.26.1 \ No newline at end of file diff --git a/django_project/gap_api/mixins/rate_limiter.py b/django_project/gap_api/mixins/rate_limiter.py index 1d43a53..9f5ac07 100644 --- a/django_project/gap_api/mixins/rate_limiter.py +++ b/django_project/gap_api/mixins/rate_limiter.py @@ -40,6 +40,10 @@ def _get_redis_key(self, granularity): return f"rate_limit:minute:{self.user_id}" elif granularity == 'hour': return f"rate_limit:hour:{self.user_id}" + elif granularity == 'day': + return f"ratelimit:{self.user_id}:day" + else: + raise ValueError("Unsupported granularity") def _increment_request_count(self): """Increment the request count for the current minute and hour.""" @@ -123,6 +127,14 @@ def is_rate_limited(self): return True return False + def is_request_allowed(self): + """Check and increment the counter if request is allowed.""" + if self.is_rate_limited(): + return False + + self._increment_request_count() + return True + class CounterSlidingWindowThrottle(BaseThrottle): """Custom throttle class using sliding window counter.""" @@ -137,8 +149,4 @@ def allow_request(self, request, view): rate_limiter = RateLimiter(request.user.id, rate_limits) - if rate_limiter.is_rate_limited(): - return False - - rate_limiter._increment_request_count() - return True + return rate_limiter.is_request_allowed() diff --git a/django_project/gap_api/tests/test_api_request_log.py b/django_project/gap_api/tests/test_api_request_log.py index 298a681..cd4dc52 100644 --- a/django_project/gap_api/tests/test_api_request_log.py +++ b/django_project/gap_api/tests/test_api_request_log.py @@ -10,6 +10,7 @@ import datetime from django.test import TestCase, RequestFactory, override_settings from django.contrib.admin import ModelAdmin +from fakeredis import FakeConnection from core.factories import UserF from gap.models import DatasetType @@ -29,6 +30,12 @@ class MockRequestObj(object): CACHES={ 'default': { 'BACKEND': 'django.core.cache.backends.redis.RedisCache', + 'LOCATION': [ + 'redis://127.0.0.1:6379', + ], + 'OPTIONS': { + 'connection_class': FakeConnection + } } } ) diff --git a/django_project/gap_api/tests/test_measurement_api.py b/django_project/gap_api/tests/test_measurement_api.py index 083412e..2b3181b 100644 --- a/django_project/gap_api/tests/test_measurement_api.py +++ b/django_project/gap_api/tests/test_measurement_api.py @@ -6,12 +6,14 @@ """ from datetime import datetime -from typing import List +from typing import List, Tuple from unittest.mock import patch +from django.test import override_settings from django.contrib.gis.geos import Polygon, MultiPolygon, Point from django.urls import reverse from rest_framework.exceptions import ValidationError +from fakeredis import FakeConnection from core.tests.common import FakeResolverMatchV1, BaseAPIViewTest from gap.factories import ( @@ -36,7 +38,7 @@ def __init__( location_input: DatasetReaderInput, start_date: datetime, end_date: datetime, output_type=DatasetReaderOutputType.JSON, - altitudes: (float, float) = None + altitudes: Tuple[float, float] = None ) -> None: """Initialize MockDatasetReader class.""" super().__init__( @@ -58,6 +60,19 @@ def get_data_values(self) -> DatasetReaderValue: ) +@override_settings( + CACHES={ + 'default': { + 'BACKEND': 'django.core.cache.backends.redis.RedisCache', + 'LOCATION': [ + 'redis://127.0.0.1:6379', + ], + 'OPTIONS': { + 'connection_class': FakeConnection + } + } + } +) class CommonMeasurementAPITest(BaseAPIViewTest): """Common class for Measurement API Test.""" diff --git a/django_project/gap_api/tests/test_rate_limiter.py b/django_project/gap_api/tests/test_rate_limiter.py new file mode 100644 index 0000000..007f48c --- /dev/null +++ b/django_project/gap_api/tests/test_rate_limiter.py @@ -0,0 +1,141 @@ +# coding=utf-8 +""" +Tomorrow Now GAP. + +.. note:: Unit tests for User API. +""" + +from django.test import TestCase, override_settings +from django.core.cache import cache +from fakeredis import FakeConnection + +from gap_api.mixins.rate_limiter import RateLimiter + + +@override_settings( + CACHES={ + 'default': { + 'BACKEND': 'django.core.cache.backends.redis.RedisCache', + 'LOCATION': [ + 'redis://127.0.0.1:6379', + ], + 'OPTIONS': { + 'connection_class': FakeConnection + } + } + } +) +class TestRateLimiter(TestCase): + """Unit test for RateLimiter class.""" + + def setUp(self): + """Set test class.""" + self.redis_client = cache._cache.get_client() + + def test_rate_limiter_allows_request(self): + """Test requests are allowed.""" + # Set up rate limiter with small rate limits for testing + rate_limiter = RateLimiter( + user_id="user123", rate_limits={1: 5, 60: 100, 1440: 1000}) + + # Send 5 requests, should all pass + for _ in range(5): + can_access = rate_limiter.is_request_allowed() + self.assertTrue(can_access) + + def test_rate_limiter_blocks_after_minute_limit(self): + """Test blocked after minute limit is exceeded.""" + # Set up rate limiter with small rate limits for testing + rate_limiter = RateLimiter( + user_id="user124", rate_limits={1: 3, 60: 100, 1440: 1000}) + + # Send 3 requests, all should pass + for _ in range(3): + can_access = rate_limiter.is_request_allowed() + self.assertTrue(can_access) + + # Send 1 more request, should be rate limited + can_access = rate_limiter.is_request_allowed() + self.assertFalse(can_access) + + def test_rate_limiter_resets_after_time(self): + """Test rate limit is reset.""" + # Set up rate limiter with a 1-minute limit of 3 requests + rate_limiter = RateLimiter(user_id="user125", rate_limits={1: 3}) + + # Send 3 requests, all should pass + for _ in range(3): + can_access = rate_limiter.is_request_allowed() + self.assertTrue(can_access) + + # Directly modify Redis to simulate time passing + # (clear the minute bucket) + current_minute = rate_limiter._get_current_minute() + minute_key = rate_limiter._get_redis_key('minute') + self.redis_client.hdel(minute_key, current_minute) + + # Send 1 more request, it should now pass + can_access = rate_limiter.is_request_allowed() + self.assertTrue(can_access) + + def test_rate_limiter_blocks_after_hour_limit(self): + """Test blocked after hour limit is exceeded.""" + # Set up rate limiter with a 60-minute (1 hour) limit of 10 requests + rate_limiter = RateLimiter(user_id="user126", rate_limits={60: 10}) + + # Manually set Redis data to simulate 10 requests in the past hour + current_hour = rate_limiter._get_current_hour() + hour_key = rate_limiter._get_redis_key('hour') + self.redis_client.hset(hour_key, current_hour, 10) + + # Send 1 more request, it should now be rate limited + can_access = rate_limiter.is_request_allowed() + self.assertFalse(can_access) + + def test_rate_limiter_blocks_after_day_limit(self): + """Test blocked after day limit is exceeded.""" + # Set up rate limiter with a 1440-minute (1 day) limit of 50 requests + rate_limiter = RateLimiter(user_id="user127", rate_limits={1440: 50}) + + # Manually set Redis data to simulate 50 requests in the past 24 hours + current_hour = rate_limiter._get_current_hour() + hour_key = rate_limiter._get_redis_key('hour') + for i in range(24): # Simulate requests for the past 24 hours + self.redis_client.hset(hour_key, current_hour - i, 5) + + # Send 1 more request, it should now be rate limited + can_access = rate_limiter.is_request_allowed() + self.assertFalse(can_access) + + def test_cleanup_old_entries(self): + """Test cleanup old entries.""" + # Set up rate limiter with a 1-minute and 60-minute limit for testing + rate_limiter = RateLimiter( + user_id="user128", rate_limits={1: 5, 60: 100}) + + # Manually set Redis data to simulate old minute and hour entries + current_minute = rate_limiter._get_current_minute() + current_hour = rate_limiter._get_current_hour() + + # Add old minute data beyond the longest minute-based window + # (assume limit is 1 minute) + minute_key = rate_limiter._get_redis_key('minute') + # Older than 1 minute window + self.redis_client.hset(minute_key, current_minute - 2, 10) + + # Add old hour data beyond the longest hour-based window + # (assume 24 hours) + hour_key = rate_limiter._get_redis_key('hour') + # Older than 24-hour window + self.redis_client.hset(hour_key, current_hour - 25, 10) + + # Perform a request to trigger cleanup + rate_limiter.is_request_allowed() + + # Verify old minute data is cleaned up + self.assertFalse( + self.redis_client.hexists(minute_key, current_minute - 2)) + + # Verify old hour data is cleaned up + self.assertFalse( + self.redis_client.hexists(hour_key, current_hour - 25)) From 58a814934883613908a4fb492b1dc2afdf7127f8 Mon Sep 17 00:00:00 2001 From: Danang Massandy Date: Wed, 30 Oct 2024 18:32:39 +0000 Subject: [PATCH 06/13] fix get_redis_key --- django_project/gap_api/mixins/rate_limiter.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/django_project/gap_api/mixins/rate_limiter.py b/django_project/gap_api/mixins/rate_limiter.py index 9f5ac07..f9906b3 100644 --- a/django_project/gap_api/mixins/rate_limiter.py +++ b/django_project/gap_api/mixins/rate_limiter.py @@ -40,8 +40,6 @@ def _get_redis_key(self, granularity): return f"rate_limit:minute:{self.user_id}" elif granularity == 'hour': return f"rate_limit:hour:{self.user_id}" - elif granularity == 'day': - return f"ratelimit:{self.user_id}:day" else: raise ValueError("Unsupported granularity") From a480d2416945006571b2d169e7563c6689e44742 Mon Sep 17 00:00:00 2001 From: Danang Massandy Date: Wed, 30 Oct 2024 18:40:35 +0000 Subject: [PATCH 07/13] add test for APIRateLimiter model --- django_project/gap_api/factories.py | 14 +++- django_project/gap_api/models/rate_limiter.py | 6 +- .../gap_api/tests/test_rate_limiter.py | 83 +++++++++++++++++++ 3 files changed, 99 insertions(+), 4 deletions(-) diff --git a/django_project/gap_api/factories.py b/django_project/gap_api/factories.py index 9651bf2..01d71ba 100644 --- a/django_project/gap_api/factories.py +++ b/django_project/gap_api/factories.py @@ -9,7 +9,7 @@ from django.contrib.gis.geos import Polygon, MultiPolygon from core.factories import UserF -from gap_api.models import APIRequestLog, Location +from gap_api.models import APIRequestLog, Location, APIRateLimiter class APIRequestLogFactory(DjangoModelFactory): @@ -51,3 +51,15 @@ class Meta: # noqa ) ) created_on = factory.Faker('date_time') + + +class APIRateLimiterFactory(DjangoModelFactory): + """Factory class for APIRateLimiter model.""" + + class Meta: # noqa + model = APIRateLimiter + + user = factory.SubFactory(UserF) + minute_limit = 10 + hour_limit = 100 + day_limit = 1000 diff --git a/django_project/gap_api/models/rate_limiter.py b/django_project/gap_api/models/rate_limiter.py index d9f52e4..fa6f8dd 100644 --- a/django_project/gap_api/models/rate_limiter.py +++ b/django_project/gap_api/models/rate_limiter.py @@ -67,9 +67,9 @@ def parse_cache_value(cache_str: str): """Parse cache value.""" values = cache_str.split(':') return { - 'minute': values[0], - 'hour': values[1], - 'day': values[2] + 'minute': int(values[0]), + 'hour': int(values[1]), + 'day': int(values[2]) } @staticmethod diff --git a/django_project/gap_api/tests/test_rate_limiter.py b/django_project/gap_api/tests/test_rate_limiter.py index 007f48c..a193c25 100644 --- a/django_project/gap_api/tests/test_rate_limiter.py +++ b/django_project/gap_api/tests/test_rate_limiter.py @@ -9,7 +9,9 @@ from django.core.cache import cache from fakeredis import FakeConnection +from gap_api.models.rate_limiter import APIRateLimiter from gap_api.mixins.rate_limiter import RateLimiter +from gap_api.factories import APIRateLimiterFactory @override_settings( @@ -139,3 +141,84 @@ def test_cleanup_old_entries(self): # Verify old hour data is cleaned up self.assertFalse( self.redis_client.hexists(hour_key, current_hour - 25)) + + +@override_settings( + CACHES={ + 'default': { + 'BACKEND': 'django.core.cache.backends.redis.RedisCache', + 'LOCATION': [ + 'redis://127.0.0.1:6379', + ], + 'OPTIONS': { + 'connection_class': FakeConnection + } + } + } +) +class TestAPIRateLimiterModel(TestCase): + """Unit test for APIRateLimiter model.""" + + def setUp(self): + """Set the test class.""" + self.global_rate_limiter = APIRateLimiterFactory.create( + user=None, + minute_limit=10, + hour_limit=100, + day_limit=1000, + ) + self.user_rate_limiter = APIRateLimiterFactory.create( + minute_limit=5, + hour_limit=50, + day_limit=500, + ) + + def test_set_cache_for_user(self): + """Test setting cache for a specific user.""" + # Set cache for user-specific config + self.user_rate_limiter.set_cache() + cached_value = cache.get(self.user_rate_limiter.cache_key) + self.assertEqual(cached_value, '5:50:500') + + def test_set_cache_for_global(self): + """Test setting cache for global config.""" + # Set cache for global config + self.global_rate_limiter.set_cache() + cached_value = cache.get(self.global_rate_limiter.cache_key) + self.assertEqual(cached_value, '10:100:1000') + + def test_clear_cache_for_user(self): + """Test clearing cache for a specific user.""" + self.user_rate_limiter.set_cache() + self.user_rate_limiter.clear_cache() + cached_value = cache.get(self.user_rate_limiter.cache_key) + self.assertIsNone(cached_value) + + def test_get_config_for_user(self): + """Test retrieving config for a specific user from cache.""" + self.user_rate_limiter.set_cache() + config = APIRateLimiter.get_config(self.user_rate_limiter.user) + self.assertEqual(config['minute'], 5) + self.assertEqual(config['hour'], 50) + self.assertEqual(config['day'], 500) + + def test_get_global_config(self): + """Test retrieving global config from cache.""" + self.global_rate_limiter.set_cache() + config = APIRateLimiter.get_global_config() + self.assertEqual(config['minute'], 10) + self.assertEqual(config['hour'], 100) + self.assertEqual(config['day'], 1000) + + def test_get_global_config_when_not_cached(self): + """Test retrieving global config when it is not cached.""" + # First clear cache to simulate no cache scenario + cache.delete(APIRateLimiter.GLOBAL_CACHE_KEY) + + # Ensure the global rate limiter is fetched and set in the cache + config = APIRateLimiter.get_global_config() + self.assertEqual(config['minute'], 10) + self.assertEqual(config['hour'], 100) + self.assertEqual(config['day'], 1000) + cached_value = cache.get(APIRateLimiter.GLOBAL_CACHE_KEY) + self.assertEqual(cached_value, '10:100:1000') From 922691c3c18121e43a004f23ccd5c17380a80a3c Mon Sep 17 00:00:00 2001 From: Danang Massandy Date: Wed, 30 Oct 2024 19:24:26 +0000 Subject: [PATCH 08/13] add waiting time in seconds --- django_project/gap_api/mixins/rate_limiter.py | 50 +++++++++++++++++-- .../gap_api/tests/test_rate_limiter.py | 39 +++++++++++++++ 2 files changed, 86 insertions(+), 3 deletions(-) diff --git a/django_project/gap_api/mixins/rate_limiter.py b/django_project/gap_api/mixins/rate_limiter.py index f9906b3..5b20da5 100644 --- a/django_project/gap_api/mixins/rate_limiter.py +++ b/django_project/gap_api/mixins/rate_limiter.py @@ -25,6 +25,7 @@ def __init__(self, user_id, rate_limits): self.user_id = user_id self.rate_limits = rate_limits self.redis: Redis = cache._cache.get_client() + self.exceeding_limits = [] def _get_current_minute(self): """Get the current timestamp rounded to the nearest min.""" @@ -119,10 +120,14 @@ def _get_request_count(self, duration_in_minutes): def is_rate_limited(self): """Check if the user is rate-limited based on defined rate limits.""" + self.exceeding_limits = [] for duration_in_minutes, max_requests in self.rate_limits.items(): request_count = self._get_request_count(duration_in_minutes) if request_count >= max_requests: - return True + self.exceeding_limits.append(duration_in_minutes) + + if len(self.exceeding_limits) > 0: + return True return False def is_request_allowed(self): @@ -133,6 +138,36 @@ def is_request_allowed(self): self._increment_request_count() return True + def get_waiting_time_in_seconds(self): + """ + Estimate the waiting time (in seconds) until the rate limit is lifted. + + This is calculated as the time remaining until the next window starts. + """ + current_time = time.time() + waiting_times = [] + + # check minute-level rate limit + if any(duration < 60 for duration in self.exceeding_limits): + next_minute_reset = (current_time // 60 + 1) * 60 + waiting_times.append(next_minute_reset - current_time) + + # check hour-level rate limit + if 60 in self.exceeding_limits: + next_hour_reset = (current_time // 3600 + 1) * 3600 + waiting_times.append(next_hour_reset - current_time) + + # check day-level rate limit + if 1440 in self.exceeding_limits: + next_day_reset = (current_time // 86400 + 1) * 86400 + waiting_times.append(next_day_reset - current_time) + + # Return the longest waiting time + if waiting_times: + return int(max(waiting_times)) + + return None + class CounterSlidingWindowThrottle(BaseThrottle): """Custom throttle class using sliding window counter.""" @@ -140,11 +175,20 @@ class CounterSlidingWindowThrottle(BaseThrottle): def allow_request(self, request, view): """Check whether request is allowed.""" rate_limits = { - 1: 100, # 100 requests per minute + 1: 3, # 100 requests per minute 60: 1000, # 1000 requests per hour 1440: 10000 # 10,000 requests per day } rate_limiter = RateLimiter(request.user.id, rate_limits) - return rate_limiter.is_request_allowed() + self.wait_time = None + is_allowed = rate_limiter.is_request_allowed() + if not is_allowed: + self.wait_time = rate_limiter.get_waiting_time_in_seconds() + + return is_allowed + + def wait(self): + """Return the waiting time in seconds.""" + return self.wait_time diff --git a/django_project/gap_api/tests/test_rate_limiter.py b/django_project/gap_api/tests/test_rate_limiter.py index a193c25..d4d91aa 100644 --- a/django_project/gap_api/tests/test_rate_limiter.py +++ b/django_project/gap_api/tests/test_rate_limiter.py @@ -142,6 +142,45 @@ def test_cleanup_old_entries(self): self.assertFalse( self.redis_client.hexists(hour_key, current_hour - 25)) + def test_get_waiting_time_in_seconds_for_minute_limit(self): + """Test get waiting time for minute limit.""" + rate_limiter = RateLimiter( + user_id="user129", rate_limits={1: 5, 60: 10, 1440: 100}) + rate_limiter.exceeding_limits = [1] + wait_time = rate_limiter.get_waiting_time_in_seconds() + self.assertTrue(0 < wait_time <= 60) + + def test_get_waiting_time_in_seconds_for_hour_limit(self): + """Test get waiting time for hour limit.""" + rate_limiter = RateLimiter( + user_id="user129", rate_limits={1: 5, 60: 10, 1440: 100}) + rate_limiter.exceeding_limits = [60] + wait_time = rate_limiter.get_waiting_time_in_seconds() + self.assertTrue(0 < wait_time <= 3600) + + def test_get_waiting_time_in_seconds_for_day_limit(self): + """Test get waiting time for day limit.""" + rate_limiter = RateLimiter( + user_id="user129", rate_limits={1: 5, 60: 10, 1440: 100}) + rate_limiter.exceeding_limits = [1440] + wait_time = rate_limiter.get_waiting_time_in_seconds() + self.assertTrue(0 < wait_time <= 86400) + + def test_get_waiting_time_in_seconds_under_limit(self): + """Test get waiting time when under limit.""" + rate_limiter = RateLimiter( + user_id="user129", rate_limits={1: 5, 60: 10, 1440: 100}) + wait_time = rate_limiter.get_waiting_time_in_seconds() + self.assertIsNone(wait_time) + + def test_get_waiting_time_in_seconds_multiple_limit(self): + """Test get waiting time when multiple limits are exceeded.""" + rate_limiter = RateLimiter( + user_id="user129", rate_limits={1: 5, 60: 10, 1440: 100}) + rate_limiter.exceeding_limits = [1, 60] + wait_time = rate_limiter.get_waiting_time_in_seconds() + self.assertTrue(0 < wait_time <= 3600) + @override_settings( CACHES={ From fbcd66f8434161f6d4a488089bca8a4629cc2bac Mon Sep 17 00:00:00 2001 From: Danang Massandy Date: Wed, 30 Oct 2024 19:56:07 +0000 Subject: [PATCH 09/13] fetch rate limit from model --- .../gap_api/fixtures/2.apiratelimiter.json | 6 +-- django_project/gap_api/mixins/rate_limiter.py | 44 ++++++++++++++++--- django_project/gap_api/models/rate_limiter.py | 17 ++++--- 3 files changed, 51 insertions(+), 16 deletions(-) diff --git a/django_project/gap_api/fixtures/2.apiratelimiter.json b/django_project/gap_api/fixtures/2.apiratelimiter.json index e737c1e..4ab764d 100644 --- a/django_project/gap_api/fixtures/2.apiratelimiter.json +++ b/django_project/gap_api/fixtures/2.apiratelimiter.json @@ -4,9 +4,9 @@ "pk": 1, "fields": { "user": null, - "minute_limit": 100, - "hour_limit": 1000, - "day_limit": 10000 + "minute_limit": 1000, + "hour_limit": 10000, + "day_limit": 100000 } } ] diff --git a/django_project/gap_api/mixins/rate_limiter.py b/django_project/gap_api/mixins/rate_limiter.py index 5b20da5..f07e0b8 100644 --- a/django_project/gap_api/mixins/rate_limiter.py +++ b/django_project/gap_api/mixins/rate_limiter.py @@ -10,6 +10,16 @@ from django.core.cache import cache from rest_framework.throttling import BaseThrottle +from gap_api.models.rate_limiter import APIRateLimiter + + +class RateLimitKey: + """Key for available rate limit.""" + + RATE_LIMIT_MINUTE_KEY = 1 + RATE_LIMIT_HOUR_KEY = 60 + RATE_LIMIT_DAY_KEY = 1440 + class RateLimiter: """RateLimiter using sliding window counter.""" @@ -153,12 +163,12 @@ def get_waiting_time_in_seconds(self): waiting_times.append(next_minute_reset - current_time) # check hour-level rate limit - if 60 in self.exceeding_limits: + if RateLimitKey.RATE_LIMIT_HOUR_KEY in self.exceeding_limits: next_hour_reset = (current_time // 3600 + 1) * 3600 waiting_times.append(next_hour_reset - current_time) # check day-level rate limit - if 1440 in self.exceeding_limits: + if RateLimitKey.RATE_LIMIT_DAY_KEY in self.exceeding_limits: next_day_reset = (current_time // 86400 + 1) * 86400 waiting_times.append(next_day_reset - current_time) @@ -172,13 +182,33 @@ def get_waiting_time_in_seconds(self): class CounterSlidingWindowThrottle(BaseThrottle): """Custom throttle class using sliding window counter.""" + def _fetch_rate_limit(self, user): + """Fetch rate limit for given user. + + if user does not have config, then it will use the global config. + """ + config = APIRateLimiter.get_config(user) + rate_limits = {} + + if config['minute'] != -1: + rate_limits[RateLimitKey.RATE_LIMIT_MINUTE_KEY] = config['minute'] + + if config['hour'] != -1: + rate_limits[RateLimitKey.RATE_LIMIT_HOUR_KEY] = config['hour'] + + if config['day'] != -1: + rate_limits[RateLimitKey.RATE_LIMIT_DAY_KEY] = config['day'] + + return rate_limits + def allow_request(self, request, view): """Check whether request is allowed.""" - rate_limits = { - 1: 3, # 100 requests per minute - 60: 1000, # 1000 requests per hour - 1440: 10000 # 10,000 requests per day - } + rate_limits = self._fetch_rate_limit(request.user) + + # check if rate_limit is disabled + # NOTE: the global config is 1k/min, 10k/hour, 100k/day from fixture + if len(rate_limits) == 0: + return True rate_limiter = RateLimiter(request.user.id, rate_limits) diff --git a/django_project/gap_api/models/rate_limiter.py b/django_project/gap_api/models/rate_limiter.py index fa6f8dd..4767fee 100644 --- a/django_project/gap_api/models/rate_limiter.py +++ b/django_project/gap_api/models/rate_limiter.py @@ -8,14 +8,15 @@ from django.db import models from django.conf import settings from django.core.cache import cache -from django.db.models.signals import post_save +from django.db.models.signals import post_save, pre_delete from django.dispatch import receiver class APIRateLimiter(models.Model): """Models that stores GAP API rate limiter.""" - GLOBAL_CACHE_KEY = 'gap-api-ratelimit-global' + CACHE_PREFIX_KEY = 'gap-api-ratelimit-' + GLOBAL_CACHE_KEY = f'{CACHE_PREFIX_KEY}global' user = models.ForeignKey( settings.AUTH_USER_MODEL, @@ -38,7 +39,7 @@ def config_name(self): def cache_key(self): """Return cache key for this config.""" if self.user: - return f'gap-api-ratelimit-{self.user.id}' + return f'{APIRateLimiter.CACHE_PREFIX_KEY}{self.user.id}' return APIRateLimiter.GLOBAL_CACHE_KEY @property @@ -92,7 +93,7 @@ def get_global_config(): @staticmethod def get_config(user): """Return config for given user.""" - cache_key = f'gap-api-ratelimit-{user.id}' + cache_key = f'{APIRateLimiter.CACHE_PREFIX_KEY}{user.id}' config_cache = cache.get(cache_key, None) if config_cache == 'global': @@ -122,7 +123,11 @@ def get_config(user): def ratelimiter_post_create( sender, instance: APIRateLimiter, created, *args, **kwargs): """Clear cache after saving the object.""" - if created: - return + cache.delete(instance.cache_key) + +@receiver(pre_delete, sender=APIRateLimiter) +def ratelimiter_pre_delete( + sender, instance: APIRateLimiter, *args, **kwargs): + """Clear cache before the model is deleted.""" cache.delete(instance.cache_key) From 626215724b9c4e23da6d868cfadfc32d32c0007a Mon Sep 17 00:00:00 2001 From: Danang Massandy Date: Wed, 30 Oct 2024 20:06:44 +0000 Subject: [PATCH 10/13] fix config empty --- django_project/gap_api/mixins/rate_limiter.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/django_project/gap_api/mixins/rate_limiter.py b/django_project/gap_api/mixins/rate_limiter.py index f07e0b8..283e2bb 100644 --- a/django_project/gap_api/mixins/rate_limiter.py +++ b/django_project/gap_api/mixins/rate_limiter.py @@ -188,7 +188,10 @@ def _fetch_rate_limit(self, user): if user does not have config, then it will use the global config. """ config = APIRateLimiter.get_config(user) + rate_limits = {} + if config is None: + return rate_limits if config['minute'] != -1: rate_limits[RateLimitKey.RATE_LIMIT_MINUTE_KEY] = config['minute'] From cb25a43ea353a4d50f24eb8841deca9c93aeed3c Mon Sep 17 00:00:00 2001 From: Danang Massandy Date: Wed, 30 Oct 2024 20:07:31 +0000 Subject: [PATCH 11/13] fix lint --- django_project/gap_api/mixins/rate_limiter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/django_project/gap_api/mixins/rate_limiter.py b/django_project/gap_api/mixins/rate_limiter.py index 283e2bb..fb507c5 100644 --- a/django_project/gap_api/mixins/rate_limiter.py +++ b/django_project/gap_api/mixins/rate_limiter.py @@ -188,7 +188,7 @@ def _fetch_rate_limit(self, user): if user does not have config, then it will use the global config. """ config = APIRateLimiter.get_config(user) - + rate_limits = {} if config is None: return rate_limits From 6984b501f25471eacad6bf19035cc5e60c5f478f Mon Sep 17 00:00:00 2001 From: Danang Massandy Date: Wed, 30 Oct 2024 20:25:34 +0000 Subject: [PATCH 12/13] fix test using patch time --- .../gap_api/tests/test_rate_limiter.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/django_project/gap_api/tests/test_rate_limiter.py b/django_project/gap_api/tests/test_rate_limiter.py index d4d91aa..0a6460c 100644 --- a/django_project/gap_api/tests/test_rate_limiter.py +++ b/django_project/gap_api/tests/test_rate_limiter.py @@ -5,6 +5,7 @@ .. note:: Unit tests for User API. """ +from mock import patch from django.test import TestCase, override_settings from django.core.cache import cache from fakeredis import FakeConnection @@ -142,44 +143,54 @@ def test_cleanup_old_entries(self): self.assertFalse( self.redis_client.hexists(hour_key, current_hour - 25)) - def test_get_waiting_time_in_seconds_for_minute_limit(self): + @patch('time.time', return_value=1730319778) + def test_get_waiting_time_in_seconds_for_minute_limit(self, mock_time): """Test get waiting time for minute limit.""" rate_limiter = RateLimiter( user_id="user129", rate_limits={1: 5, 60: 10, 1440: 100}) rate_limiter.exceeding_limits = [1] wait_time = rate_limiter.get_waiting_time_in_seconds() self.assertTrue(0 < wait_time <= 60) + mock_time.assert_called_once() - def test_get_waiting_time_in_seconds_for_hour_limit(self): + @patch('time.time', return_value=1730319778) + def test_get_waiting_time_in_seconds_for_hour_limit(self, mock_time): """Test get waiting time for hour limit.""" rate_limiter = RateLimiter( user_id="user129", rate_limits={1: 5, 60: 10, 1440: 100}) rate_limiter.exceeding_limits = [60] wait_time = rate_limiter.get_waiting_time_in_seconds() self.assertTrue(0 < wait_time <= 3600) + mock_time.assert_called_once() - def test_get_waiting_time_in_seconds_for_day_limit(self): + @patch('time.time', return_value=1730319778) + def test_get_waiting_time_in_seconds_for_day_limit(self, mock_time): """Test get waiting time for day limit.""" rate_limiter = RateLimiter( user_id="user129", rate_limits={1: 5, 60: 10, 1440: 100}) rate_limiter.exceeding_limits = [1440] wait_time = rate_limiter.get_waiting_time_in_seconds() self.assertTrue(0 < wait_time <= 86400) + mock_time.assert_called_once() - def test_get_waiting_time_in_seconds_under_limit(self): + @patch('time.time', return_value=1730319778) + def test_get_waiting_time_in_seconds_under_limit(self, mock_time): """Test get waiting time when under limit.""" rate_limiter = RateLimiter( user_id="user129", rate_limits={1: 5, 60: 10, 1440: 100}) wait_time = rate_limiter.get_waiting_time_in_seconds() self.assertIsNone(wait_time) + mock_time.assert_called_once() - def test_get_waiting_time_in_seconds_multiple_limit(self): + @patch('time.time', return_value=1730319778) + def test_get_waiting_time_in_seconds_multiple_limit(self, mock_time): """Test get waiting time when multiple limits are exceeded.""" rate_limiter = RateLimiter( user_id="user129", rate_limits={1: 5, 60: 10, 1440: 100}) rate_limiter.exceeding_limits = [1, 60] wait_time = rate_limiter.get_waiting_time_in_seconds() self.assertTrue(0 < wait_time <= 3600) + mock_time.assert_called_once() @override_settings( From db761d4dc604f1e5eef752b8d8326da963325724 Mon Sep 17 00:00:00 2001 From: Danang Massandy Date: Wed, 30 Oct 2024 20:53:48 +0000 Subject: [PATCH 13/13] add throttling class to other GAP API --- django_project/core/tests/common.py | 17 ++++++- .../gap_api/api_views/crop_insight.py | 4 +- django_project/gap_api/api_views/location.py | 3 +- django_project/gap_api/api_views/user.py | 4 +- .../gap_api/tests/test_measurement_api.py | 15 ------- .../gap_api/tests/test_rate_limiter.py | 45 ++++++++++++++++++- 6 files changed, 68 insertions(+), 20 deletions(-) diff --git a/django_project/core/tests/common.py b/django_project/core/tests/common.py index bff6c82..4efea5a 100644 --- a/django_project/core/tests/common.py +++ b/django_project/core/tests/common.py @@ -5,11 +5,26 @@ .. note:: Common class for unit tests. """ -from django.test import TestCase +from fakeredis import FakeConnection +from django.test import TestCase, override_settings from rest_framework.test import APIRequestFactory + from core.factories import UserF +@override_settings( + CACHES={ + 'default': { + 'BACKEND': 'django.core.cache.backends.redis.RedisCache', + 'LOCATION': [ + 'redis://127.0.0.1:6379', + ], + 'OPTIONS': { + 'connection_class': FakeConnection + } + } + } +) class BaseAPIViewTest(TestCase): """Base class for API test.""" diff --git a/django_project/gap_api/api_views/crop_insight.py b/django_project/gap_api/api_views/crop_insight.py index 6d2cba1..06789ff 100644 --- a/django_project/gap_api/api_views/crop_insight.py +++ b/django_project/gap_api/api_views/crop_insight.py @@ -26,6 +26,7 @@ CropInsightSerializer, CropInsightGeojsonSerializer ) from gap_api.utils.helper import ApiTag +from gap_api.mixins import GAPAPILoggingMixin, CounterSlidingWindowThrottle def default_fields(): @@ -41,10 +42,11 @@ def default_fields(): return [] -class CropPlanAPI(APIView): +class CropPlanAPI(GAPAPILoggingMixin, APIView): """API class for crop plan data.""" permission_classes = [IsAuthenticated] + throttle_classes = [CounterSlidingWindowThrottle] outputs = [ 'json', 'geojson', diff --git a/django_project/gap_api/api_views/location.py b/django_project/gap_api/api_views/location.py index 21724c5..7eed2d4 100644 --- a/django_project/gap_api/api_views/location.py +++ b/django_project/gap_api/api_views/location.py @@ -26,7 +26,7 @@ from gap_api.serializers.common import APIErrorSerializer from gap_api.serializers.location import LocationSerializer from gap_api.utils.helper import ApiTag -from gap_api.mixins import GAPAPILoggingMixin +from gap_api.mixins import GAPAPILoggingMixin, CounterSlidingWindowThrottle from gap_api.utils.fiona import ( validate_shapefile_zip, validate_collection_crs, @@ -42,6 +42,7 @@ class LocationAPI(GAPAPILoggingMixin, APIView): """API class for uploading location.""" permission_classes = [IsAuthenticated] + throttle_classes = [CounterSlidingWindowThrottle] parser_classes = (MultiPartParser,) api_parameters = [ openapi.Parameter( diff --git a/django_project/gap_api/api_views/user.py b/django_project/gap_api/api_views/user.py index a9a9a0b..f4c383d 100644 --- a/django_project/gap_api/api_views/user.py +++ b/django_project/gap_api/api_views/user.py @@ -13,12 +13,14 @@ from gap_api.serializers.common import APIErrorSerializer from gap_api.serializers.user import UserInfoSerializer from gap_api.utils.helper import ApiTag +from gap_api.mixins import GAPAPILoggingMixin, CounterSlidingWindowThrottle -class UserInfo(APIView): +class UserInfo(GAPAPILoggingMixin, APIView): """API to return user info.""" permission_classes = [IsAuthenticated] + throttle_classes = [CounterSlidingWindowThrottle] @swagger_auto_schema( operation_id='user-info', diff --git a/django_project/gap_api/tests/test_measurement_api.py b/django_project/gap_api/tests/test_measurement_api.py index 2b3181b..d92b8eb 100644 --- a/django_project/gap_api/tests/test_measurement_api.py +++ b/django_project/gap_api/tests/test_measurement_api.py @@ -9,11 +9,9 @@ from typing import List, Tuple from unittest.mock import patch -from django.test import override_settings from django.contrib.gis.geos import Polygon, MultiPolygon, Point from django.urls import reverse from rest_framework.exceptions import ValidationError -from fakeredis import FakeConnection from core.tests.common import FakeResolverMatchV1, BaseAPIViewTest from gap.factories import ( @@ -60,19 +58,6 @@ def get_data_values(self) -> DatasetReaderValue: ) -@override_settings( - CACHES={ - 'default': { - 'BACKEND': 'django.core.cache.backends.redis.RedisCache', - 'LOCATION': [ - 'redis://127.0.0.1:6379', - ], - 'OPTIONS': { - 'connection_class': FakeConnection - } - } - } -) class CommonMeasurementAPITest(BaseAPIViewTest): """Common class for Measurement API Test.""" diff --git a/django_project/gap_api/tests/test_rate_limiter.py b/django_project/gap_api/tests/test_rate_limiter.py index 0a6460c..c4d8535 100644 --- a/django_project/gap_api/tests/test_rate_limiter.py +++ b/django_project/gap_api/tests/test_rate_limiter.py @@ -10,8 +10,12 @@ from django.core.cache import cache from fakeredis import FakeConnection +from gap.factories import UserF from gap_api.models.rate_limiter import APIRateLimiter -from gap_api.mixins.rate_limiter import RateLimiter +from gap_api.mixins.rate_limiter import ( + RateLimiter, + CounterSlidingWindowThrottle +) from gap_api.factories import APIRateLimiterFactory @@ -34,6 +38,7 @@ class TestRateLimiter(TestCase): def setUp(self): """Set test class.""" self.redis_client = cache._cache.get_client() + self.user = UserF.create() def test_rate_limiter_allows_request(self): """Test requests are allowed.""" @@ -192,6 +197,36 @@ def test_get_waiting_time_in_seconds_multiple_limit(self, mock_time): self.assertTrue(0 < wait_time <= 3600) mock_time.assert_called_once() + @patch('gap_api.models.rate_limiter.APIRateLimiter.get_config') + def test_fetch_rate_limit(self, mock_get_config): + """Test fetch_rate_limit.""" + throttle = CounterSlidingWindowThrottle() + mock_get_config.return_value = None + rate_limits = throttle._fetch_rate_limit(self.user) + self.assertEqual(len(rate_limits), 0) + mock_get_config.assert_called_once() + mock_get_config.reset_mock() + + mock_get_config.return_value = { + 'minute': -1, + 'hour': -1, + 'day': -1, + } + rate_limits = throttle._fetch_rate_limit(self.user) + self.assertEqual(len(rate_limits), 0) + mock_get_config.assert_called_once() + mock_get_config.reset_mock() + + mock_get_config.return_value = { + 'minute': 10, + 'hour': 100, + 'day': 1000, + } + rate_limits = throttle._fetch_rate_limit(self.user) + self.assertEqual(len(rate_limits), 3) + mock_get_config.assert_called_once() + mock_get_config.reset_mock() + @override_settings( CACHES={ @@ -223,6 +258,14 @@ def setUp(self): day_limit=500, ) + def test_get_config_name(self): + """Test config_name.""" + self.assertEqual(self.global_rate_limiter.config_name, 'global') + self.assertEqual( + self.user_rate_limiter.config_name, + self.user_rate_limiter.user.username + ) + def test_set_cache_for_user(self): """Test setting cache for a specific user.""" # Set cache for user-specific config