Skip to content

Commit 425112f

Browse files
committed
OpenConceptLab/ocl_issues#2121 | API Throttling framework
1 parent 9027eae commit 425112f

File tree

24 files changed

+444
-6
lines changed

24 files changed

+444
-6
lines changed

core/client_configs/views.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from core.client_configs.serializers import ClientConfigSerializer, ClientConfigTemplateSerializer
88
from core.common.views import BaseAPIView
99
from .models import ClientConfig
10+
from ..common.throttling import ThrottleUtil
1011

1112

1213
class ClientConfigBaseView(generics.GenericAPIView):
@@ -15,6 +16,9 @@ class ClientConfigBaseView(generics.GenericAPIView):
1516
queryset = ClientConfig.objects.filter(is_active=True)
1617
serializer_class = ClientConfigSerializer
1718

19+
def get_throttles(self):
20+
return ThrottleUtil.get_throttles_by_user_plan(self.request.user)
21+
1822

1923
class ClientConfigView(ClientConfigBaseView, RetrieveAPIView, UpdateAPIView, DestroyAPIView):
2024
def perform_destroy(self, instance: ClientConfig):
@@ -50,6 +54,9 @@ def post(self, request, *args, **kwargs): # pylint: disable=unused-argument
5054
class ResourceTemplatesView(ListAPIView, CreateAPIView):
5155
serializer_class = ClientConfigTemplateSerializer
5256

57+
def get_throttles(self):
58+
return ThrottleUtil.get_throttles_by_user_plan(self.request.user)
59+
5360
def get_queryset(self):
5461
user = self.request.user
5562
return ClientConfig.objects.filter(

core/collections/views.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
canonical_url_param
5656
from core.common.tasks import add_references, export_collection, delete_collection, index_expansion_concepts, \
5757
index_expansion_mappings
58+
from core.common.throttling import ThrottleUtil
5859
from core.common.utils import compact_dict_by_values, parse_boolean_query_param
5960
from core.common.views import BaseAPIView, BaseLogoView, ConceptContainerExtraRetrieveUpdateDestroyView
6061
from core.concepts.documents import ConceptDocument
@@ -1103,6 +1104,9 @@ class CollectionClientConfigsView(CollectionBaseView, ResourceClientConfigsView)
11031104
class ReferenceExpressionResolveView(APIView):
11041105
serializer_class = ReferenceExpressionResolveSerializer
11051106

1107+
def get_throttles(self):
1108+
return ThrottleUtil.get_throttles_by_user_plan(self.request.user)
1109+
11061110
def get_results(self):
11071111
data = self.request.data
11081112
if not isinstance(data, list):

core/common/signals.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from django.db.models.signals import pre_save, post_save
22
from django.dispatch import receiver
3+
from pydash import get
34

45
from core.common.models import BaseModel
56
from core.orgs.models import Organization
6-
from core.users.models import UserProfile
7+
from core.users.models import UserProfile, UserRateLimit
78

89

910
@receiver(pre_save)
@@ -15,6 +16,8 @@ def stamp_uri(sender, instance, **kwargs): # pylint: disable=unused-argument
1516
@receiver(post_save, sender=Organization)
1617
@receiver(post_save, sender=UserProfile)
1718
def propagate_owner_status(sender, instance=None, created=False, **kwargs): # pylint: disable=unused-argument
19+
if created and instance.__class__ == UserProfile and not get(instance, 'api_rate_limit'):
20+
UserRateLimit(user=instance).save()
1821
if created and instance.__class__ == Organization and instance.id != 1:
1922
instance.record_create_event()
2023
if not created and instance:

core/common/throttling.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from django.conf import settings
2+
from pydash import get
3+
from rest_framework.throttling import UserRateThrottle
4+
5+
6+
class GuestMinuteThrottle(UserRateThrottle):
7+
scope = 'guest_minute'
8+
9+
10+
class GuestDayThrottle(UserRateThrottle):
11+
scope = 'guest_day'
12+
13+
14+
class LiteMinuteThrottle(UserRateThrottle):
15+
scope = 'lite_minute'
16+
17+
18+
class LiteDayThrottle(UserRateThrottle):
19+
scope = 'lite_day'
20+
21+
22+
class PremiumDayThrottle(UserRateThrottle):
23+
scope = 'premium_day'
24+
25+
26+
class PremiumMinuteThrottle(UserRateThrottle):
27+
scope = 'premium_minute'
28+
29+
30+
class ThrottleUtil:
31+
@staticmethod
32+
def get_limit_remaining(throttle, request, view):
33+
key = throttle.get_cache_key(request, view)
34+
if key is not None:
35+
history = throttle.cache.get(key, None)
36+
if history is None:
37+
return None
38+
while history and history[-1] <= throttle.timer() - throttle.duration:
39+
history.pop()
40+
remaining = throttle.num_requests - len(history)
41+
else:
42+
remaining = 'unlimited'
43+
44+
return remaining
45+
46+
@staticmethod
47+
def get_throttles_by_user_plan(user):
48+
if not settings.ENABLE_THROTTLING:
49+
return []
50+
# order is important, first one has to be minute throttle
51+
if get(user, 'api_rate_limit.is_premium'):
52+
return [PremiumMinuteThrottle(), PremiumDayThrottle()]
53+
if get(user, 'api_rate_limit.is_guest') or not get(user, 'is_authenticated'):
54+
return [GuestMinuteThrottle(), GuestDayThrottle()]
55+
return [LiteMinuteThrottle(), LiteDayThrottle()]

core/common/views.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from core.common.search import CustomESSearch
3131
from core.common.serializers import RootSerializer
3232
from core.common.swagger_parameters import all_resource_query_param
33+
from core.common.throttling import ThrottleUtil
3334
from core.common.utils import compact_dict_by_values, to_snake_case, parse_updated_since_param, \
3435
to_int, get_falsy_values, get_truthy_values, format_url_for_search
3536
from core.concepts.permissions import CanViewParentDictionary, CanEditParentDictionary
@@ -59,6 +60,9 @@ class BaseAPIView(generics.GenericAPIView, PathWalkerMixin):
5960
facet_class = None
6061
total_count = 0
6162

63+
def get_throttles(self):
64+
return ThrottleUtil.get_throttles_by_user_plan(self.request.user)
65+
6266
def has_no_kwargs(self):
6367
return len(self.kwargs.values()) == 0
6468

@@ -913,6 +917,9 @@ def __set_params(self):
913917
class SourceChildExtrasBaseView:
914918
default_qs_sort_attr = '-created_at'
915919

920+
def get_throttles(self):
921+
return ThrottleUtil.get_throttles_by_user_plan(self.request.user)
922+
916923
def get_object(self):
917924
queryset = self.get_queryset()
918925

@@ -980,6 +987,9 @@ class APIVersionView(APIView): # pragma: no cover
980987
permission_classes = (AllowAny,)
981988
swagger_schema = None
982989

990+
def get_throttles(self):
991+
return ThrottleUtil.get_throttles_by_user_plan(self.request.user)
992+
983993
@staticmethod
984994
def get(_):
985995
return Response(__version__)
@@ -989,6 +999,9 @@ class ChangeLogView(APIView): # pragma: no cover
989999
permission_classes = (AllowAny, )
9901000
swagger_schema = None
9911001

1002+
def get_throttles(self):
1003+
return ThrottleUtil.get_throttles_by_user_plan(self.request.user)
1004+
9921005
@staticmethod
9931006
def get(_):
9941007
resp = requests.get('https://raw.githubusercontent.com/OpenConceptLab/oclapi2/master/changelog.md')
@@ -1043,6 +1056,9 @@ def post(self, request, *args, **kwargs): # pylint: disable=unused-argument
10431056
class FeedbackView(APIView): # pragma: no cover
10441057
permission_classes = (AllowAny, )
10451058

1059+
def get_throttles(self):
1060+
return ThrottleUtil.get_throttles_by_user_plan(self.request.user)
1061+
10461062
@staticmethod
10471063
@swagger_auto_schema(request_body=openapi.Schema(
10481064
type=openapi.TYPE_OBJECT,
@@ -1141,6 +1157,9 @@ class AbstractChecksumView(APIView):
11411157
permission_classes = (IsAuthenticated,)
11421158
smart = False
11431159

1160+
def get_throttles(self):
1161+
return ThrottleUtil.get_throttles_by_user_plan(self.request.user)
1162+
11441163
@swagger_auto_schema(
11451164
manual_parameters=[all_resource_query_param],
11461165
request_body=openapi.Schema(

core/concepts/views.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
cascade_levels_param, cascade_direction_param, cascade_view_hierarchy, return_map_types_param,
2828
omit_if_exists_in_param, equivalency_map_types_param, search_from_latest_repo_header)
2929
from core.common.tasks import delete_concept, make_hierarchy
30+
from core.common.throttling import ThrottleUtil
3031
from core.common.utils import to_parent_uri_from_kwargs, generate_temp_version, get_truthy_values, to_int
3132
from core.common.views import SourceChildCommonBaseView, SourceChildExtrasView, \
3233
SourceChildExtraRetrieveUpdateDestroyView, BaseAPIView
@@ -752,6 +753,9 @@ class ConceptsHierarchyAmendAdminView(APIView): # pragma: no cover
752753
swagger_schema = None
753754
permission_classes = (IsAdminUser, )
754755

756+
def get_throttles(self):
757+
return ThrottleUtil.get_throttles_by_user_plan(self.request.user)
758+
755759
@staticmethod
756760
def post(request):
757761
concept_map = request.data

core/fhir/views.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,17 @@
55
from rest_framework.views import APIView
66

77
from core import settings
8+
from core.common.throttling import ThrottleUtil
89

910
logger = logging.getLogger('oclapi')
1011

1112

1213
class CapabilityStatementView(APIView):
1314
permission_classes = (AllowAny,)
1415

16+
def get_throttles(self):
17+
return ThrottleUtil.get_throttles_by_user_plan(self.request.user)
18+
1519
def get(self, request):
1620
mode = request.query_params.get('mode')
1721
fhir_base_url = f"{settings.API_BASE_URL}/fhir"

core/importers/views.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from rest_framework.views import APIView
2121

2222
from core.common.constants import DEPRECATED_API_HEADER
23+
from core.common.throttling import ThrottleUtil
2324
from core.common.views import BaseAPIView
2425
from core.common.tasks import bulk_import_new
2526
from core.common.swagger_parameters import update_if_exists_param, task_param, result_param, username_param, \
@@ -71,6 +72,9 @@ def import_response(request, import_queue, data, threads=None, inline=False, dep
7172

7273

7374
class ImportRetrieveDestroyMixin(BaseAPIView):
75+
def get_throttles(self):
76+
return ThrottleUtil.get_throttles_by_user_plan(self.request.user)
77+
7478
def get_serializer_class(self):
7579
if self.request.GET.get('task'):
7680
return TaskDetailSerializer
@@ -141,6 +145,9 @@ class BulkImportParallelInlineView(APIView):
141145
permission_classes = (IsAuthenticated, )
142146
deprecated = True
143147

148+
def get_throttles(self):
149+
return ThrottleUtil.get_throttles_by_user_plan(self.request.user)
150+
144151
def get_parsers(self):
145152
if 'application/json' in [self.request.META.get('CONTENT_TYPE')]:
146153
return [JSONParser()]
@@ -225,6 +232,9 @@ class BulkImportFileUploadView(APIView): # pragma: no cover
225232
parser_classes = (MultiPartParser, )
226233
deprecated = True
227234

235+
def get_throttles(self):
236+
return ThrottleUtil.get_throttles_by_user_plan(self.request.user)
237+
228238
@swagger_auto_schema(
229239
manual_parameters=[update_if_exists_param, file_upload_param],
230240
deprecated=True
@@ -255,6 +265,9 @@ class BulkImportFileURLView(APIView): # pragma: no cover
255265
parser_classes = (MultiPartParser, )
256266
deprecated = True
257267

268+
def get_throttles(self):
269+
return ThrottleUtil.get_throttles_by_user_plan(self.request.user)
270+
258271
@swagger_auto_schema(
259272
manual_parameters=[update_if_exists_param, file_url_param],
260273
deprecated=True
@@ -336,6 +349,9 @@ class BulkImportInlineView(APIView): # pragma: no cover
336349
parser_classes = (MultiPartParser, FormParser)
337350
deprecated = True
338351

352+
def get_throttles(self):
353+
return ThrottleUtil.get_throttles_by_user_plan(self.request.user)
354+
339355
@swagger_auto_schema(
340356
manual_parameters=[update_if_exists_param, file_url_param, file_upload_param],
341357
deprecated=True

core/indexes/views.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from core.common.swagger_parameters import apps_param, ids_param, resources_body_param, uri_param, filter_param
1111
from core.common.tasks import rebuild_indexes, populate_indexes, batch_index_resources
12+
from core.common.throttling import ThrottleUtil
1213
from core.common.utils import get_resource_class_from_resource_name
1314
from core.tasks.models import Task
1415

@@ -18,6 +19,9 @@ class BaseESIndexView(APIView): # pragma: no cover
1819
parser_classes = (MultiPartParser,)
1920
task = None
2021

22+
def get_throttles(self):
23+
return ThrottleUtil.get_throttles_by_user_plan(self.request.user)
24+
2125
@swagger_auto_schema(manual_parameters=[apps_param])
2226
def post(self, request):
2327
apps = request.data.get('apps', None)
@@ -49,6 +53,9 @@ class ResourceIndexView(APIView):
4953
permission_classes = (IsAdminUser,)
5054
parser_classes = (MultiPartParser,)
5155

56+
def get_throttles(self):
57+
return ThrottleUtil.get_throttles_by_user_plan(self.request.user)
58+
5259
@swagger_auto_schema(manual_parameters=[ids_param, uri_param, filter_param, resources_body_param])
5360
def post(self, _, resource):
5461
model = get_resource_class_from_resource_name(resource)

core/integration_tests/tests_users.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,76 @@ def test_put_204(self):
785785
self.assertTrue(inactive_user.is_active)
786786

787787

788+
class UserRateLimitViewTest(OCLAPITestCase):
789+
def setUp(self):
790+
super().setUp()
791+
self.superuser = UserProfile.objects.get(username='ocladmin')
792+
793+
def test_put_bad_request(self):
794+
user = UserProfileFactory()
795+
self.assertEqual(user.api_rate_limit.rate_plan, 'lite')
796+
797+
response = self.client.put(
798+
f'/users/{user.username}/rate-limit/',
799+
{},
800+
HTTP_AUTHORIZATION='Token ' + self.superuser.get_token(),
801+
format='json'
802+
)
803+
self.assertEqual(response.status_code, 400)
804+
self.assertEqual(
805+
response.data,
806+
{'detail': ErrorDetail(string='"rate_plan" needs to be one of "guest", "lite" or "premium"', code='bad_request')}
807+
)
808+
809+
user.refresh_from_db()
810+
self.assertEqual(user.api_rate_limit.rate_plan, 'lite')
811+
812+
response = self.client.put(
813+
f'/users/{user.username}/rate-limit/',
814+
{'rate_plan': 'blah'},
815+
HTTP_AUTHORIZATION='Token ' + self.superuser.get_token(),
816+
format='json'
817+
)
818+
self.assertEqual(response.status_code, 400)
819+
self.assertEqual(response.data, {'rate_plan': ["Value 'blah' is not a valid choice."]})
820+
821+
self.assertEqual(user.api_rate_limit.rate_plan, 'lite')
822+
823+
response = self.client.put(
824+
f'/users/{user.username}/rate-limit/',
825+
{'rate_plan': 'blah'},
826+
HTTP_AUTHORIZATION='Token ' + user.get_token(),
827+
format='json'
828+
)
829+
self.assertEqual(response.status_code, 403)
830+
831+
def test_put_204(self):
832+
user = UserProfileFactory()
833+
self.assertEqual(user.api_rate_limit.rate_plan, 'lite')
834+
835+
response = self.client.put(
836+
f'/users/{user.username}/rate-limit/',
837+
{'rate_plan': 'guest'},
838+
HTTP_AUTHORIZATION='Token ' + self.superuser.get_token(),
839+
format='json'
840+
)
841+
self.assertEqual(response.status_code, 204)
842+
843+
user.refresh_from_db()
844+
self.assertEqual(user.api_rate_limit.rate_plan, 'guest')
845+
846+
response = self.client.put(
847+
f'/users/{user.username}/rate-limit/',
848+
{'rate_plan': 'premium'},
849+
HTTP_AUTHORIZATION='Token ' + self.superuser.get_token(),
850+
format='json'
851+
)
852+
self.assertEqual(response.status_code, 204)
853+
854+
user.refresh_from_db()
855+
self.assertEqual(user.api_rate_limit.rate_plan, 'premium')
856+
857+
788858
class UserStaffToggleViewTest(OCLAPITestCase):
789859
def setUp(self):
790860
super().setUp()

0 commit comments

Comments
 (0)