diff options
Diffstat (limited to 'pydis_site/apps/api')
-rw-r--r-- | pydis_site/apps/api/serializers.py | 98 | ||||
-rw-r--r-- | pydis_site/apps/api/tests/test_users.py | 221 | ||||
-rw-r--r-- | pydis_site/apps/api/viewsets/bot/user.py | 116 |
3 files changed, 422 insertions, 13 deletions
diff --git a/pydis_site/apps/api/serializers.py b/pydis_site/apps/api/serializers.py index f9a5517e..25c5c82e 100644 --- a/pydis_site/apps/api/serializers.py +++ b/pydis_site/apps/api/serializers.py @@ -1,7 +1,16 @@ """Converters from Django models to data interchange formats and back.""" -from rest_framework.serializers import ModelSerializer, PrimaryKeyRelatedField, ValidationError +from django.db.models.query import QuerySet +from django.db.utils import IntegrityError +from rest_framework.exceptions import NotFound +from rest_framework.serializers import ( + IntegerField, + ListSerializer, + ModelSerializer, + PrimaryKeyRelatedField, + ValidationError +) +from rest_framework.settings import api_settings from rest_framework.validators import UniqueTogetherValidator -from rest_framework_bulk import BulkSerializerMixin from .models import ( BotSetting, @@ -235,15 +244,98 @@ class RoleSerializer(ModelSerializer): fields = ('id', 'name', 'colour', 'permissions', 'position') -class UserSerializer(BulkSerializerMixin, ModelSerializer): +class UserListSerializer(ListSerializer): + """List serializer for User model to handle bulk updates.""" + + def create(self, validated_data: list) -> list: + """Override create method to optimize django queries.""" + new_users = [] + seen = set() + + for user_dict in validated_data: + if user_dict["id"] in seen: + raise ValidationError( + {"id": [f"User with ID {user_dict['id']} given multiple times."]} + ) + seen.add(user_dict["id"]) + new_users.append(User(**user_dict)) + + User.objects.bulk_create(new_users, ignore_conflicts=True) + return [] + + def update(self, queryset: QuerySet, validated_data: list) -> list: + """ + Override update method to support bulk updates. + + ref:https://www.django-rest-framework.org/api-guide/serializers/#customizing-multiple-update + """ + object_ids = set() + + for data in validated_data: + try: + if data["id"] in object_ids: + # If request data contains users with same ID. + raise ValidationError( + {"id": [f"User with ID {data['id']} given multiple times."]} + ) + except KeyError: + # If user ID not provided in request body. + raise ValidationError( + {"id": ["This field is required."]} + ) + object_ids.add(data["id"]) + + # filter queryset + filtered_instances = queryset.filter(id__in=object_ids) + + instance_mapping = {user.id: user for user in filtered_instances} + + updated = [] + fields_to_update = set() + for user_data in validated_data: + for key in user_data: + fields_to_update.add(key) + + try: + user = instance_mapping[user_data["id"]] + except KeyError: + raise NotFound({"detail": f"User with id {user_data['id']} not found."}) + + user.__dict__.update(user_data) + updated.append(user) + + fields_to_update.remove("id") + + if not fields_to_update: + # Raise ValidationError when only id field is given. + raise ValidationError( + {api_settings.NON_FIELD_ERRORS_KEY: ["Insufficient data provided."]} + ) + + User.objects.bulk_update(updated, fields_to_update) + return updated + + +class UserSerializer(ModelSerializer): """A class providing (de-)serialization of `User` instances.""" + # ID field must be explicitly set as the default id field is read-only. + id = IntegerField(min_value=0) + class Meta: """Metadata defined for the Django REST Framework.""" model = User fields = ('id', 'name', 'discriminator', 'roles', 'in_guild') depth = 1 + list_serializer_class = UserListSerializer + + def create(self, validated_data: dict) -> User: + """Override create method to catch IntegrityError.""" + try: + return super().create(validated_data) + except IntegrityError: + raise ValidationError({"id": ["User with ID already present."]}) class NominationSerializer(ModelSerializer): diff --git a/pydis_site/apps/api/tests/test_users.py b/pydis_site/apps/api/tests/test_users.py index 76a21d3a..d03785ae 100644 --- a/pydis_site/apps/api/tests/test_users.py +++ b/pydis_site/apps/api/tests/test_users.py @@ -48,6 +48,13 @@ class CreationTests(APISubdomainTestCase): position=1 ) + cls.user = User.objects.create( + id=11, + name="Name doesn't matter.", + discriminator=1122, + in_guild=True + ) + def test_accepts_valid_data(self): url = reverse('bot:user-list', host='api') data = { @@ -92,7 +99,7 @@ class CreationTests(APISubdomainTestCase): response = self.client.post(url, data=data) self.assertEqual(response.status_code, 201) - self.assertEqual(response.json(), data) + self.assertEqual(response.json(), []) def test_returns_400_for_unknown_role_id(self): url = reverse('bot:user-list', host='api') @@ -118,6 +125,176 @@ class CreationTests(APISubdomainTestCase): response = self.client.post(url, data=data) self.assertEqual(response.status_code, 400) + def test_returns_400_for_user_recreation(self): + """Return 201 if User is already present in database as it skips User creation.""" + url = reverse('bot:user-list', host='api') + data = [{ + 'id': 11, + 'name': 'You saw nothing.', + 'discriminator': 112, + 'in_guild': True + }] + response = self.client.post(url, data=data) + self.assertEqual(response.status_code, 201) + + def test_returns_400_for_duplicate_request_users(self): + """Return 400 if 2 Users with same ID is passed in the request data.""" + url = reverse('bot:user-list', host='api') + data = [ + { + 'id': 11, + 'name': 'You saw nothing.', + 'discriminator': 112, + 'in_guild': True + }, + { + 'id': 11, + 'name': 'You saw nothing part 2.', + 'discriminator': 1122, + 'in_guild': False + } + ] + response = self.client.post(url, data=data) + self.assertEqual(response.status_code, 400) + + def test_returns_400_for_existing_user(self): + """Returns 400 if user is already present in DB.""" + url = reverse('bot:user-list', host='api') + data = { + 'id': 11, + 'name': 'You saw nothing part 3.', + 'discriminator': 1122, + 'in_guild': True + } + response = self.client.post(url, data=data) + self.assertEqual(response.status_code, 400) + + +class MultiPatchTests(APISubdomainTestCase): + @classmethod + def setUpTestData(cls): + cls.role_developer = Role.objects.create( + id=159, + name="Developer", + colour=2, + permissions=0b01010010101, + position=10, + ) + cls.user_1 = User.objects.create( + id=1, + name="Patch test user 1.", + discriminator=1111, + in_guild=True + ) + cls.user_2 = User.objects.create( + id=2, + name="Patch test user 2.", + discriminator=2222, + in_guild=True + ) + + def test_multiple_users_patch(self): + url = reverse("bot:user-bulk-patch", host="api") + data = [ + { + "id": 1, + "name": "User 1 patched!", + "discriminator": 1010, + "roles": [self.role_developer.id], + "in_guild": False + }, + { + "id": 2, + "name": "User 2 patched!" + } + ] + + response = self.client.patch(url, data=data) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json()[0], data[0]) + + user_2 = User.objects.get(id=2) + self.assertEqual(user_2.name, data[1]["name"]) + + def test_returns_400_for_missing_user_id(self): + url = reverse("bot:user-bulk-patch", host="api") + data = [ + { + "name": "I am ghost user!", + "discriminator": 1010, + "roles": [self.role_developer.id], + "in_guild": False + }, + { + "name": "patch me? whats my id?" + } + ] + response = self.client.patch(url, data=data) + self.assertEqual(response.status_code, 400) + + def test_returns_404_for_not_found_user(self): + url = reverse("bot:user-bulk-patch", host="api") + data = [ + { + "id": 1, + "name": "User 1 patched again!!!", + "discriminator": 1010, + "roles": [self.role_developer.id], + "in_guild": False + }, + { + "id": 22503405, + "name": "User unknown not patched!" + } + ] + response = self.client.patch(url, data=data) + self.assertEqual(response.status_code, 404) + + def test_returns_400_for_bad_data(self): + url = reverse("bot:user-bulk-patch", host="api") + data = [ + { + "id": 1, + "in_guild": "Catch me!" + }, + { + "id": 2, + "discriminator": "find me!" + } + ] + + response = self.client.patch(url, data=data) + self.assertEqual(response.status_code, 400) + + def test_returns_400_for_insufficient_data(self): + url = reverse("bot:user-bulk-patch", host="api") + data = [ + { + "id": 1, + }, + { + "id": 2, + } + ] + response = self.client.patch(url, data=data) + self.assertEqual(response.status_code, 400) + + def test_returns_400_for_duplicate_request_users(self): + """Return 400 if 2 Users with same ID is passed in the request data.""" + url = reverse("bot:user-bulk-patch", host="api") + data = [ + { + 'id': 1, + 'name': 'You saw nothing.', + }, + { + 'id': 1, + 'name': 'You saw nothing part 2.', + } + ] + response = self.client.patch(url, data=data) + self.assertEqual(response.status_code, 400) + class UserModelTests(APISubdomainTestCase): @classmethod @@ -175,6 +352,48 @@ class UserModelTests(APISubdomainTestCase): self.assertEqual(self.user_with_roles.username, "Test User with two roles#0001") +class UserPaginatorTests(APISubdomainTestCase): + @classmethod + def setUpTestData(cls): + users = [] + for i in range(1, 10_001): + users.append(User( + id=i, + name=f"user{i}", + discriminator=1111, + in_guild=True + )) + cls.users = User.objects.bulk_create(users) + + def test_returns_single_page_response(self): + url = reverse("bot:user-list", host="api") + response = self.client.get(url).json() + self.assertIsNone(response["next_page_no"]) + self.assertIsNone(response["previous_page_no"]) + + def test_returns_next_page_number(self): + User.objects.create( + id=10_001, + name="user10001", + discriminator=1111, + in_guild=True + ) + url = reverse("bot:user-list", host="api") + response = self.client.get(url).json() + self.assertEqual(2, response["next_page_no"]) + + def test_returns_previous_page_number(self): + User.objects.create( + id=10_001, + name="user10001", + discriminator=1111, + in_guild=True + ) + url = reverse("bot:user-list", host="api") + response = self.client.get(url, {"page": 2}).json() + self.assertEqual(1, response["previous_page_no"]) + + class UserMetricityTests(APISubdomainTestCase): @classmethod def setUpTestData(cls): diff --git a/pydis_site/apps/api/viewsets/bot/user.py b/pydis_site/apps/api/viewsets/bot/user.py index 352d77c0..3ab71186 100644 --- a/pydis_site/apps/api/viewsets/bot/user.py +++ b/pydis_site/apps/api/viewsets/bot/user.py @@ -1,26 +1,65 @@ +import typing +from collections import OrderedDict + from rest_framework import status from rest_framework.decorators import action +from rest_framework.pagination import PageNumberPagination from rest_framework.request import Request from rest_framework.response import Response +from rest_framework.serializers import ModelSerializer from rest_framework.viewsets import ModelViewSet -from rest_framework_bulk import BulkCreateModelMixin from pydis_site.apps.api.models.bot.metricity import Metricity, NotFound from pydis_site.apps.api.models.bot.user import User from pydis_site.apps.api.serializers import UserSerializer -class UserViewSet(BulkCreateModelMixin, ModelViewSet): +class UserListPagination(PageNumberPagination): + """Custom pagination class for the User Model.""" + + page_size = 10000 + page_size_query_param = "page_size" + + def get_next_page_number(self) -> typing.Optional[int]: + """Get the next page number.""" + if not self.page.has_next(): + return None + page_number = self.page.next_page_number() + return page_number + + def get_previous_page_number(self) -> typing.Optional[int]: + """Get the previous page number.""" + if not self.page.has_previous(): + return None + + page_number = self.page.previous_page_number() + return page_number + + def get_paginated_response(self, data: list) -> Response: + """Override method to send modified response.""" + return Response(OrderedDict([ + ('count', self.page.paginator.count), + ('next_page_no', self.get_next_page_number()), + ('previous_page_no', self.get_previous_page_number()), + ('results', data) + ])) + + +class UserViewSet(ModelViewSet): """ View providing CRUD operations on Discord users through the bot. ## Routes ### GET /bot/users - Returns all users currently known. + Returns all users currently known with pagination. #### Response format - >>> [ - ... { + >>> { + ... 'count': 95000, + ... 'next_page_no': "2", + ... 'previous_page_no': None, + ... 'results': [ + ... { ... 'id': 409107086526644234, ... 'name': "Python", ... 'discriminator': 4329, @@ -31,8 +70,13 @@ class UserViewSet(BulkCreateModelMixin, ModelViewSet): ... 458226699344019457 ... ], ... 'in_guild': True - ... } - ... ] + ... }, + ... ] + ... } + + #### Optional Query Parameters + - page_size: number of Users in one page, defaults to 10,000 + - page: page number #### Status codes - 200: returned on success @@ -74,6 +118,7 @@ class UserViewSet(BulkCreateModelMixin, ModelViewSet): ### POST /bot/users Adds a single or multiple new users. The roles attached to the user(s) must be roles known by the site. + Users that already exist in the database will be skipped. #### Request body >>> { @@ -85,11 +130,13 @@ class UserViewSet(BulkCreateModelMixin, ModelViewSet): ... } Alternatively, request users can be POSTed as a list of above objects, - in which case multiple users will be created at once. + in which case multiple users will be created at once. In this case, + the response is an empty list. #### Status codes - 201: returned on success - 400: if one of the given roles does not exist, or one of the given fields is invalid + - 400: if multiple user objects with the same id are given ### PUT /bot/users/<snowflake:int> Update the user with the given `snowflake`. @@ -127,6 +174,34 @@ class UserViewSet(BulkCreateModelMixin, ModelViewSet): - 400: if the request body was invalid, see response body for details - 404: if the user with the given `snowflake` could not be found + ### BULK PATCH /bot/users/bulk_patch + Update users with the given `ids` and `details`. + `id` field and at least one other field is mandatory. + + #### Request body + >>> [ + ... { + ... 'id': int, + ... 'name': str, + ... 'discriminator': int, + ... 'roles': List[int], + ... 'in_guild': bool + ... }, + ... { + ... 'id': int, + ... 'name': str, + ... 'discriminator': int, + ... 'roles': List[int], + ... 'in_guild': bool + ... }, + ... ] + + #### Status codes + - 200: returned on success + - 400: if the request body was invalid, see response body for details + - 400: if multiple user objects with the same id are given + - 404: if the user with the given id does not exist + ### DELETE /bot/users/<snowflake:int> Deletes the user with the given `snowflake`. @@ -136,7 +211,30 @@ class UserViewSet(BulkCreateModelMixin, ModelViewSet): """ serializer_class = UserSerializer - queryset = User.objects + queryset = User.objects.all().order_by("id") + pagination_class = UserListPagination + + def get_serializer(self, *args, **kwargs) -> ModelSerializer: + """Set Serializer many attribute to True if request body contains a list.""" + if isinstance(kwargs.get('data', {}), list): + kwargs['many'] = True + + return super().get_serializer(*args, **kwargs) + + @action(detail=False, methods=["PATCH"], name='user-bulk-patch') + def bulk_patch(self, request: Request) -> Response: + """Update multiple User objects in a single request.""" + serializer = self.get_serializer( + instance=self.get_queryset(), + data=request.data, + many=True, + partial=True + ) + + serializer.is_valid(raise_exception=True) + serializer.save() + + return Response(serializer.data, status=status.HTTP_200_OK) @action(detail=True) def metricity_data(self, request: Request, pk: str = None) -> Response: |