From 830c0baab4ba925fdcaa859c87c4c8ce893c9ea1 Mon Sep 17 00:00:00 2001 From: FrankApiyo Date: Thu, 22 Aug 2024 18:06:28 +0300 Subject: [PATCH] Add throttling based on URL Do not consider user or user-agent or headers --- onadata/libs/tests/test_throttle.py | 37 +++++++++++++++++++++++-- onadata/libs/throttle.py | 32 +++++++++++++++++++++ onadata/settings/github_actions_test.py | 2 ++ 3 files changed, 69 insertions(+), 2 deletions(-) diff --git a/onadata/libs/tests/test_throttle.py b/onadata/libs/tests/test_throttle.py index c7b2c2a8e0..ffb64e7aaf 100644 --- a/onadata/libs/tests/test_throttle.py +++ b/onadata/libs/tests/test_throttle.py @@ -3,10 +3,43 @@ from rest_framework.test import APIRequestFactory -from onadata.libs.throttle import RequestHeaderThrottle +from onadata.libs.throttle import RequestHeaderThrottle, URLThrottle -class ThrottlingTests(TestCase): +class URLThrottleTests(TestCase): + """ + Test Renderer class. + """ + + def setUp(self): + """ + Reset the cache so that no throttles will be active + """ + cache.clear() + self.factory = APIRequestFactory() + self.throttle = URLThrottle() + + def test_requests_are_not_throttled_for_get(self): + request = self.factory.get("/bob/submission") + key = self.throttle.get_cache_key(request, None) + self.assertEqual(key, None) + + def test_requests_are_not_throttled_for_non_submission_urls(self): + request = self.factory.post("/projects/") + key = self.throttle.get_cache_key(request, None) + self.assertEqual(key, None) + + def test_requests_are_throttled(self): + request = self.factory.post("/bob/submission") + key = self.throttle.get_cache_key(request, None) + self.assertEqual(key, 'throttle_method_POST_path_/bob/submission') + + request = self.factory.post("/project/124/submission") + key = self.throttle.get_cache_key(request, None) + self.assertEqual(key, 'throttle_method_POST_path_/project/124/submission') + + +class RequestHeaderThrottlingTests(TestCase): """ Test Renderer class. """ diff --git a/onadata/libs/throttle.py b/onadata/libs/throttle.py index ff79442a0b..7c9ba6a603 100644 --- a/onadata/libs/throttle.py +++ b/onadata/libs/throttle.py @@ -7,6 +7,38 @@ from rest_framework.throttling import SimpleRateThrottle +class URLThrottle(SimpleRateThrottle): + + @property + def rate(self): + return getattr( + settings, + "THROTTLE_USERS_RATE", + "300/min" + ) + + @property + def throttled_users(self): + return getattr( + settings, + "THROTTLE_USERS", + [], + ) + + def get_form_owner_or_project_from_url(self, url): + path_segments = url.split("/") + if len(path_segments) > 1: + return path_segments[-2] + return None + + def get_cache_key(self, request, _): + form_owner_or_project = self.get_form_owner_or_project_from_url(request.path) + if form_owner_or_project in self.throttled_users \ + and request.method == 'POST' and '/submission' in request.path: + return f"throttle_method_{request.method}_path_{request.path}" + return None + + class RequestHeaderThrottle(SimpleRateThrottle): """ Custom Throttling class that throttles requests that match a specific diff --git a/onadata/settings/github_actions_test.py b/onadata/settings/github_actions_test.py index 68d2dd503f..49ef6bd054 100644 --- a/onadata/settings/github_actions_test.py +++ b/onadata/settings/github_actions_test.py @@ -69,3 +69,5 @@ ODK_TOKEN_FERNET_KEY = "ROsB4T8s1rCJskAdgpTQEKfH2x2K_EX_YBi3UFyoYng=" # nosec OPENID_CONNECT_PROVIDERS = {} AUTH_PASSWORD_VALIDATORS = [] + +THROTTLE_USERS=["bob", "123"]