diff options
| -rw-r--r-- | pydis_site/apps/api/serializers.py | 77 | ||||
| -rw-r--r-- | pydis_site/apps/api/viewsets/bot/user.py | 60 | 
2 files changed, 136 insertions, 1 deletions
diff --git a/pydis_site/apps/api/serializers.py b/pydis_site/apps/api/serializers.py index 52e0d972..757faeae 100644 --- a/pydis_site/apps/api/serializers.py +++ b/pydis_site/apps/api/serializers.py @@ -1,5 +1,13 @@  """Converters from Django models to data interchange formats and back.""" -from rest_framework.serializers import ModelSerializer, PrimaryKeyRelatedField, ValidationError +from django.db.models.query import QuerySet +from rest_framework.serializers import ( +    ListSerializer, +    ModelSerializer, +    PrimaryKeyRelatedField, +    ValidationError +) +from rest_framework.settings import api_settings +from rest_framework.utils import html  from rest_framework.validators import UniqueTogetherValidator  from rest_framework_bulk import BulkSerializerMixin @@ -260,6 +268,72 @@ class TagSerializer(ModelSerializer):          fields = ('title', 'embed') +class UserListSerializer(ListSerializer): +    """List serializer for User model to handle bulk updates.""" + +    def to_internal_value(self, data: list) -> list: +        """ +        Overriding `to_internal_value` function with a few changes to support bulk updates. + +        List of dicts of native values <- List of dicts of primitive datatypes. +        """ +        if html.is_html_input(data): +            data = html.parse_html_list(data, default=[]) + +        if not isinstance(data, list): +            message = self.error_messages['not_a_list'].format( +                input_type=type(data).__name__ +            ) +            raise ValidationError({ +                api_settings.NON_FIELD_ERRORS_KEY: [message] +            }, code='not_a_list') + +        if not self.allow_empty and len(data) == 0: +            message = self.error_messages['empty'] +            raise ValidationError({ +                api_settings.NON_FIELD_ERRORS_KEY: [message] +            }, code='empty') + +        ret = [] +        errors = [] + +        for item in data: +            # inserted code +            # bug: https://github.com/miki725/django-rest-framework-bulk/issues/68 +            # ----------------- +            try: +                self.child.instance = self.instance.get(id=item['id']) +            except User.DoesNotExist: +                self.child.instance = None +            # ----------------- +            self.child.initial_data = item +            try: +                validated = self.child.run_validation(item) +            except ValidationError as exc: +                errors.append(exc.detail) +            else: +                ret.append(validated) +                errors.append({}) + +        if any(errors): +            raise ValidationError(errors) + +        return ret + +    def update(self, instance: QuerySet, validated_data: list) -> list: +        """Override update method to support bulk updates.""" +        instance_mapping = {user.id: user for user in instance} +        data_mapping = {item['id']: item for item in validated_data} + +        updated = [] +        for book_id, data in data_mapping.items(): +            book = instance_mapping.get(book_id, None) +            if book is not None: +                updated.append(self.child.update(book, data)) + +        return updated + +  class UserSerializer(BulkSerializerMixin, ModelSerializer):      """A class providing (de-)serialization of `User` instances.""" @@ -269,6 +343,7 @@ class UserSerializer(BulkSerializerMixin, ModelSerializer):          model = User          fields = ('id', 'name', 'discriminator', 'roles', 'in_guild')          depth = 1 +        list_serializer_class = UserListSerializer  class NominationSerializer(ModelSerializer): diff --git a/pydis_site/apps/api/viewsets/bot/user.py b/pydis_site/apps/api/viewsets/bot/user.py index b016bb66..d64ca113 100644 --- a/pydis_site/apps/api/viewsets/bot/user.py +++ b/pydis_site/apps/api/viewsets/bot/user.py @@ -1,4 +1,8 @@ +from rest_framework import status +from rest_framework.decorators import action  from rest_framework.pagination import PageNumberPagination +from rest_framework.request import Request +from rest_framework.response import Response  from rest_framework.viewsets import ModelViewSet  from rest_framework_bulk import BulkCreateModelMixin @@ -137,3 +141,59 @@ class UserViewSet(BulkCreateModelMixin, ModelViewSet):      serializer_class = UserSerializer      queryset = User.objects.all()      pagination_class = UserListPagination + +    @action(detail=False, methods=["PATCH"]) +    def bulk_patch(self, request: Request) -> Response: +        """ +        Update multiple User objects in a single request. + +        ## Route +        ### PATCH /bot/users/bulk_patch +        Update all users with the IDs. +        `id` field is mandatory, rest are optional. + +        #### Request body +        >>> [ +        ...     { +        ...     'id': int, +        ...     'name': str, +        ...     'discriminator': int, +        ...     'roles': List[int], +        ...     'in_guild': bool +        ...     }, +        ...     { +        ...     'id': int, +        ...     'name': str, +        ...     'discriminator': int, +        ...     'roles': List[int], +        ...     'in_guild': bool +        ...     }, +        ... ] + +        #### Status codes +        - 200: Returned on success. +        - 400: if the request body was invalid, see response body for details. +        """ +        queryset = self.get_queryset() +        try: +            object_ids = [item["id"] for item in request.data] +        except KeyError: +            # user ID not provided in request body. +            resp = { +                "Error": "User ID not provided." +            } +            return Response(resp, status=status.HTTP_400_BAD_REQUEST) + +        filtered_instances = queryset.filter(id__in=object_ids) + +        serializer = self.get_serializer( +            instance=filtered_instances, +            data=request.data, +            many=True, +            partial=True +        ) + +        if serializer.is_valid(): +            serializer.save() +            return Response(serializer.data, status=status.HTTP_200_OK) +        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)  |