diff options
Diffstat (limited to 'pydis_site/apps')
| -rw-r--r-- | pydis_site/apps/api/serializers.py | 98 | ||||
| -rw-r--r-- | pydis_site/apps/api/tests/test_users.py | 221 | ||||
| -rw-r--r-- | pydis_site/apps/api/viewsets/bot/user.py | 120 | 
3 files changed, 426 insertions, 13 deletions
| diff --git a/pydis_site/apps/api/serializers.py b/pydis_site/apps/api/serializers.py index f9a5517e..25c5c82e 100644 --- a/pydis_site/apps/api/serializers.py +++ b/pydis_site/apps/api/serializers.py @@ -1,7 +1,16 @@  """Converters from Django models to data interchange formats and back.""" -from rest_framework.serializers import ModelSerializer, PrimaryKeyRelatedField, ValidationError +from django.db.models.query import QuerySet +from django.db.utils import IntegrityError +from rest_framework.exceptions import NotFound +from rest_framework.serializers import ( +    IntegerField, +    ListSerializer, +    ModelSerializer, +    PrimaryKeyRelatedField, +    ValidationError +) +from rest_framework.settings import api_settings  from rest_framework.validators import UniqueTogetherValidator -from rest_framework_bulk import BulkSerializerMixin  from .models import (      BotSetting, @@ -235,15 +244,98 @@ class RoleSerializer(ModelSerializer):          fields = ('id', 'name', 'colour', 'permissions', 'position') -class UserSerializer(BulkSerializerMixin, ModelSerializer): +class UserListSerializer(ListSerializer): +    """List serializer for User model to handle bulk updates.""" + +    def create(self, validated_data: list) -> list: +        """Override create method to optimize django queries.""" +        new_users = [] +        seen = set() + +        for user_dict in validated_data: +            if user_dict["id"] in seen: +                raise ValidationError( +                    {"id": [f"User with ID {user_dict['id']} given multiple times."]} +                ) +            seen.add(user_dict["id"]) +            new_users.append(User(**user_dict)) + +        User.objects.bulk_create(new_users, ignore_conflicts=True) +        return [] + +    def update(self, queryset: QuerySet, validated_data: list) -> list: +        """ +        Override update method to support bulk updates. + +        ref:https://www.django-rest-framework.org/api-guide/serializers/#customizing-multiple-update +        """ +        object_ids = set() + +        for data in validated_data: +            try: +                if data["id"] in object_ids: +                    # If request data contains users with same ID. +                    raise ValidationError( +                        {"id": [f"User with ID {data['id']} given multiple times."]} +                    ) +            except KeyError: +                # If user ID not provided in request body. +                raise ValidationError( +                    {"id": ["This field is required."]} +                ) +            object_ids.add(data["id"]) + +        # filter queryset +        filtered_instances = queryset.filter(id__in=object_ids) + +        instance_mapping = {user.id: user for user in filtered_instances} + +        updated = [] +        fields_to_update = set() +        for user_data in validated_data: +            for key in user_data: +                fields_to_update.add(key) + +                try: +                    user = instance_mapping[user_data["id"]] +                except KeyError: +                    raise NotFound({"detail": f"User with id {user_data['id']} not found."}) + +                user.__dict__.update(user_data) +            updated.append(user) + +        fields_to_update.remove("id") + +        if not fields_to_update: +            # Raise ValidationError when only id field is given. +            raise ValidationError( +                {api_settings.NON_FIELD_ERRORS_KEY: ["Insufficient data provided."]} +            ) + +        User.objects.bulk_update(updated, fields_to_update) +        return updated + + +class UserSerializer(ModelSerializer):      """A class providing (de-)serialization of `User` instances.""" +    # ID field must be explicitly set as the default id field is read-only. +    id = IntegerField(min_value=0) +      class Meta:          """Metadata defined for the Django REST Framework."""          model = User          fields = ('id', 'name', 'discriminator', 'roles', 'in_guild')          depth = 1 +        list_serializer_class = UserListSerializer + +    def create(self, validated_data: dict) -> User: +        """Override create method to catch IntegrityError.""" +        try: +            return super().create(validated_data) +        except IntegrityError: +            raise ValidationError({"id": ["User with ID already present."]})  class NominationSerializer(ModelSerializer): diff --git a/pydis_site/apps/api/tests/test_users.py b/pydis_site/apps/api/tests/test_users.py index a02fce8a..825e4edb 100644 --- a/pydis_site/apps/api/tests/test_users.py +++ b/pydis_site/apps/api/tests/test_users.py @@ -45,6 +45,13 @@ class CreationTests(APISubdomainTestCase):              position=1          ) +        cls.user = User.objects.create( +            id=11, +            name="Name doesn't matter.", +            discriminator=1122, +            in_guild=True +        ) +      def test_accepts_valid_data(self):          url = reverse('bot:user-list', host='api')          data = { @@ -89,7 +96,7 @@ class CreationTests(APISubdomainTestCase):          response = self.client.post(url, data=data)          self.assertEqual(response.status_code, 201) -        self.assertEqual(response.json(), data) +        self.assertEqual(response.json(), [])      def test_returns_400_for_unknown_role_id(self):          url = reverse('bot:user-list', host='api') @@ -115,6 +122,176 @@ class CreationTests(APISubdomainTestCase):          response = self.client.post(url, data=data)          self.assertEqual(response.status_code, 400) +    def test_returns_400_for_user_recreation(self): +        """Return 201 if User is already present in database as it skips User creation.""" +        url = reverse('bot:user-list', host='api') +        data = [{ +            'id': 11, +            'name': 'You saw nothing.', +            'discriminator': 112, +            'in_guild': True +        }] +        response = self.client.post(url, data=data) +        self.assertEqual(response.status_code, 201) + +    def test_returns_400_for_duplicate_request_users(self): +        """Return 400 if 2 Users with same ID is passed in the request data.""" +        url = reverse('bot:user-list', host='api') +        data = [ +            { +                'id': 11, +                'name': 'You saw nothing.', +                'discriminator': 112, +                'in_guild': True +            }, +            { +                'id': 11, +                'name': 'You saw nothing part 2.', +                'discriminator': 1122, +                'in_guild': False +            } +        ] +        response = self.client.post(url, data=data) +        self.assertEqual(response.status_code, 400) + +    def test_returns_400_for_existing_user(self): +        """Returns 400 if user is already present in DB.""" +        url = reverse('bot:user-list', host='api') +        data = { +            'id': 11, +            'name': 'You saw nothing part 3.', +            'discriminator': 1122, +            'in_guild': True +        } +        response = self.client.post(url, data=data) +        self.assertEqual(response.status_code, 400) + + +class MultiPatchTests(APISubdomainTestCase): +    @classmethod +    def setUpTestData(cls): +        cls.role_developer = Role.objects.create( +            id=159, +            name="Developer", +            colour=2, +            permissions=0b01010010101, +            position=10, +        ) +        cls.user_1 = User.objects.create( +            id=1, +            name="Patch test user 1.", +            discriminator=1111, +            in_guild=True +        ) +        cls.user_2 = User.objects.create( +            id=2, +            name="Patch test user 2.", +            discriminator=2222, +            in_guild=True +        ) + +    def test_multiple_users_patch(self): +        url = reverse("bot:user-bulk-patch", host="api") +        data = [ +            { +                "id": 1, +                "name": "User 1 patched!", +                "discriminator": 1010, +                "roles": [self.role_developer.id], +                "in_guild": False +            }, +            { +                "id": 2, +                "name": "User 2 patched!" +            } +        ] + +        response = self.client.patch(url, data=data) +        self.assertEqual(response.status_code, 200) +        self.assertEqual(response.json()[0], data[0]) + +        user_2 = User.objects.get(id=2) +        self.assertEqual(user_2.name, data[1]["name"]) + +    def test_returns_400_for_missing_user_id(self): +        url = reverse("bot:user-bulk-patch", host="api") +        data = [ +            { +                "name": "I am ghost user!", +                "discriminator": 1010, +                "roles": [self.role_developer.id], +                "in_guild": False +            }, +            { +                "name": "patch me? whats my id?" +            } +        ] +        response = self.client.patch(url, data=data) +        self.assertEqual(response.status_code, 400) + +    def test_returns_404_for_not_found_user(self): +        url = reverse("bot:user-bulk-patch", host="api") +        data = [ +            { +                "id": 1, +                "name": "User 1 patched again!!!", +                "discriminator": 1010, +                "roles": [self.role_developer.id], +                "in_guild": False +            }, +            { +                "id": 22503405, +                "name": "User unknown not patched!" +            } +        ] +        response = self.client.patch(url, data=data) +        self.assertEqual(response.status_code, 404) + +    def test_returns_400_for_bad_data(self): +        url = reverse("bot:user-bulk-patch", host="api") +        data = [ +            { +                "id": 1, +                "in_guild": "Catch me!" +            }, +            { +                "id": 2, +                "discriminator": "find me!" +            } +        ] + +        response = self.client.patch(url, data=data) +        self.assertEqual(response.status_code, 400) + +    def test_returns_400_for_insufficient_data(self): +        url = reverse("bot:user-bulk-patch", host="api") +        data = [ +            { +                "id": 1, +            }, +            { +                "id": 2, +            } +        ] +        response = self.client.patch(url, data=data) +        self.assertEqual(response.status_code, 400) + +    def test_returns_400_for_duplicate_request_users(self): +        """Return 400 if 2 Users with same ID is passed in the request data.""" +        url = reverse("bot:user-bulk-patch", host="api") +        data = [ +            { +                'id': 1, +                'name': 'You saw nothing.', +            }, +            { +                'id': 1, +                'name': 'You saw nothing part 2.', +            } +        ] +        response = self.client.patch(url, data=data) +        self.assertEqual(response.status_code, 400) +  class UserModelTests(APISubdomainTestCase):      @classmethod @@ -170,3 +347,45 @@ class UserModelTests(APISubdomainTestCase):      def test_correct_username_formatting(self):          """Tests the username property with both name and discriminator formatted together."""          self.assertEqual(self.user_with_roles.username, "Test User with two roles#0001") + + +class UserPaginatorTests(APISubdomainTestCase): +    @classmethod +    def setUpTestData(cls): +        users = [] +        for i in range(1, 10_001): +            users.append(User( +                id=i, +                name=f"user{i}", +                discriminator=1111, +                in_guild=True +            )) +        cls.users = User.objects.bulk_create(users) + +    def test_returns_single_page_response(self): +        url = reverse("bot:user-list", host="api") +        response = self.client.get(url).json() +        self.assertIsNone(response["next_page_no"]) +        self.assertIsNone(response["previous_page_no"]) + +    def test_returns_next_page_number(self): +        User.objects.create( +            id=10_001, +            name="user10001", +            discriminator=1111, +            in_guild=True +        ) +        url = reverse("bot:user-list", host="api") +        response = self.client.get(url).json() +        self.assertEqual(2, response["next_page_no"]) + +    def test_returns_previous_page_number(self): +        User.objects.create( +            id=10_001, +            name="user10001", +            discriminator=1111, +            in_guild=True +        ) +        url = reverse("bot:user-list", host="api") +        response = self.client.get(url, {"page": 2}).json() +        self.assertEqual(1, response["previous_page_no"]) diff --git a/pydis_site/apps/api/viewsets/bot/user.py b/pydis_site/apps/api/viewsets/bot/user.py index 9571b3d7..3e4b627e 100644 --- a/pydis_site/apps/api/viewsets/bot/user.py +++ b/pydis_site/apps/api/viewsets/bot/user.py @@ -1,21 +1,64 @@ +import typing +from collections import OrderedDict + +from rest_framework import status +from rest_framework.decorators import action +from rest_framework.pagination import PageNumberPagination +from rest_framework.request import Request +from rest_framework.response import Response +from rest_framework.serializers import ModelSerializer  from rest_framework.viewsets import ModelViewSet -from rest_framework_bulk import BulkCreateModelMixin  from pydis_site.apps.api.models.bot.user import User  from pydis_site.apps.api.serializers import UserSerializer -class UserViewSet(BulkCreateModelMixin, ModelViewSet): +class UserListPagination(PageNumberPagination): +    """Custom pagination class for the User Model.""" + +    page_size = 10000 +    page_size_query_param = "page_size" + +    def get_next_page_number(self) -> typing.Optional[int]: +        """Get the next page number.""" +        if not self.page.has_next(): +            return None +        page_number = self.page.next_page_number() +        return page_number + +    def get_previous_page_number(self) -> typing.Optional[int]: +        """Get the previous page number.""" +        if not self.page.has_previous(): +            return None + +        page_number = self.page.previous_page_number() +        return page_number + +    def get_paginated_response(self, data: list) -> Response: +        """Override method to send modified response.""" +        return Response(OrderedDict([ +            ('count', self.page.paginator.count), +            ('next_page_no', self.get_next_page_number()), +            ('previous_page_no', self.get_previous_page_number()), +            ('results', data) +        ])) + + +class UserViewSet(ModelViewSet):      """      View providing CRUD operations on Discord users through the bot.      ## Routes      ### GET /bot/users -    Returns all users currently known. +    Returns all users currently known with pagination.      #### Response format -    >>> [ -    ...     { +    >>> { +    ...     'count': 95000, +    ...     'next_page_no': "2", +    ...     'previous_page_no': None, +    ...     'results': [ +    ...      {      ...         'id': 409107086526644234,      ...         'name': "Python",      ...         'discriminator': 4329, @@ -26,8 +69,13 @@ class UserViewSet(BulkCreateModelMixin, ModelViewSet):      ...             458226699344019457      ...         ],      ...         'in_guild': True -    ...     } -    ... ] +    ...     }, +    ...     ] +    ... } + +    #### Optional Query Parameters +    - page_size: number of Users in one page, defaults to 10,000 +    - page: page number      #### Status codes      - 200: returned on success @@ -56,6 +104,7 @@ class UserViewSet(BulkCreateModelMixin, ModelViewSet):      ### POST /bot/users      Adds a single or multiple new users.      The roles attached to the user(s) must be roles known by the site. +    Users that already exist in the database will be skipped.      #### Request body      >>> { @@ -67,11 +116,13 @@ class UserViewSet(BulkCreateModelMixin, ModelViewSet):      ... }      Alternatively, request users can be POSTed as a list of above objects, -    in which case multiple users will be created at once. +    in which case multiple users will be created at once. In this case, +    the response is an empty list.      #### Status codes      - 201: returned on success      - 400: if one of the given roles does not exist, or one of the given fields is invalid +    - 400: if multiple user objects with the same id are given      ### PUT /bot/users/<snowflake:int>      Update the user with the given `snowflake`. @@ -109,6 +160,34 @@ class UserViewSet(BulkCreateModelMixin, ModelViewSet):      - 400: if the request body was invalid, see response body for details      - 404: if the user with the given `snowflake` could not be found +    ### BULK PATCH /bot/users/bulk_patch +    Update users with the given `ids` and `details`. +    `id` field and at least one other field is mandatory. + +    #### Request body +    >>> [ +    ...     { +    ...         'id': int, +    ...         'name': str, +    ...         'discriminator': int, +    ...         'roles': List[int], +    ...         'in_guild': bool +    ...     }, +    ...     { +    ...         'id': int, +    ...         'name': str, +    ...         'discriminator': int, +    ...         'roles': List[int], +    ...         'in_guild': bool +    ...     }, +    ... ] + +    #### Status codes +    - 200: returned on success +    - 400: if the request body was invalid, see response body for details +    - 400: if multiple user objects with the same id are given +    - 404: if the user with the given id does not exist +      ### DELETE /bot/users/<snowflake:int>      Deletes the user with the given `snowflake`. @@ -118,4 +197,27 @@ class UserViewSet(BulkCreateModelMixin, ModelViewSet):      """      serializer_class = UserSerializer -    queryset = User.objects +    queryset = User.objects.all().order_by("id") +    pagination_class = UserListPagination + +    def get_serializer(self, *args, **kwargs) -> ModelSerializer: +        """Set Serializer many attribute to True if request body contains a list.""" +        if isinstance(kwargs.get('data', {}), list): +            kwargs['many'] = True + +        return super().get_serializer(*args, **kwargs) + +    @action(detail=False, methods=["PATCH"], name='user-bulk-patch') +    def bulk_patch(self, request: Request) -> Response: +        """Update multiple User objects in a single request.""" +        serializer = self.get_serializer( +            instance=self.get_queryset(), +            data=request.data, +            many=True, +            partial=True +        ) + +        serializer.is_valid(raise_exception=True) +        serializer.save() + +        return Response(serializer.data, status=status.HTTP_200_OK) | 
