diff options
Diffstat (limited to '')
| -rw-r--r-- | pydis_site/apps/api/serializers.py | 73 | ||||
| -rw-r--r-- | pydis_site/apps/api/viewsets/bot/user.py | 20 | 
2 files changed, 35 insertions, 58 deletions
| diff --git a/pydis_site/apps/api/serializers.py b/pydis_site/apps/api/serializers.py index 0589ce77..21c488a8 100644 --- a/pydis_site/apps/api/serializers.py +++ b/pydis_site/apps/api/serializers.py @@ -1,15 +1,13 @@  """Converters from Django models to data interchange formats and back."""  from django.db.models.query import QuerySet  from rest_framework.serializers import ( +    IntegerField,      ListSerializer,      ModelSerializer,      PrimaryKeyRelatedField,      ValidationError  ) -from rest_framework.settings import api_settings -from rest_framework.utils import html  from rest_framework.validators import UniqueTogetherValidator -from rest_framework_bulk import BulkSerializerMixin  from .models import (      BotSetting, @@ -271,55 +269,18 @@ class TagSerializer(ModelSerializer):  class UserListSerializer(ListSerializer):      """List serializer for User model to handle bulk updates.""" -    def to_internal_value(self, data: list) -> list: -        """ -        Overriding `to_internal_value` function with a few changes to support bulk updates. - -        ref: https://github.com/miki725/django-rest-framework-bulk/issues/68 +    def create(self, validated_data: list) -> list: +        """Override create method to optimize django queries.""" +        present_users = User.objects.all() +        new_users = [] +        present_user_ids = [user.id for user in present_users] -        List of dicts of native values <- List of dicts of primitive datatypes. -        """ -        if html.is_html_input(data): -            data = html.parse_html_list(data, default=[]) +        for user_dict in validated_data: +            if user_dict["id"] in present_user_ids: +                raise ValidationError({"id": "User already exists."}) +            new_users.append(User(**user_dict)) -        if not isinstance(data, list): -            message = self.error_messages['not_a_list'].format( -                input_type=type(data).__name__ -            ) -            raise ValidationError({ -                api_settings.NON_FIELD_ERRORS_KEY: [message] -            }, code='not_a_list') - -        if not self.allow_empty and len(data) == 0: -            message = self.error_messages['empty'] -            raise ValidationError({ -                api_settings.NON_FIELD_ERRORS_KEY: [message] -            }, code='empty') - -        ret = [] -        errors = [] - -        for item in data: -            # inserted code -            # ----------------- -            try: -                self.child.instance = self.instance.get(id=item['id']) -            except (User.DoesNotExist, AttributeError): -                self.child.instance = None -            # ----------------- -            self.child.initial_data = item -            try: -                validated = self.child.run_validation(item) -            except ValidationError as exc: -                errors.append(exc.detail) -            else: -                ret.append(validated) -                errors.append({}) - -        if any(errors): -            raise ValidationError(errors) - -        return ret +        return User.objects.bulk_create(new_users)      def update(self, instance: QuerySet, validated_data: list) -> list:          """ @@ -331,17 +292,19 @@ class UserListSerializer(ListSerializer):          data_mapping = {item['id']: item for item in validated_data}          updated = [] -        for book_id, data in data_mapping.items(): -            book = instance_mapping.get(book_id, None) -            if book is not None: -                updated.append(self.child.update(book, data)) +        for user_id, data in data_mapping.items(): +            user = instance_mapping.get(user_id, None) +            if user: +                updated.append(self.child.update(user, data))          return updated -class UserSerializer(BulkSerializerMixin, ModelSerializer): +class UserSerializer(ModelSerializer):      """A class providing (de-)serialization of `User` instances.""" +    id = IntegerField(min_value=0) +      class Meta:          """Metadata defined for the Django REST Framework.""" diff --git a/pydis_site/apps/api/viewsets/bot/user.py b/pydis_site/apps/api/viewsets/bot/user.py index d64ca113..d015fe71 100644 --- a/pydis_site/apps/api/viewsets/bot/user.py +++ b/pydis_site/apps/api/viewsets/bot/user.py @@ -3,8 +3,8 @@ from rest_framework.decorators import action  from rest_framework.pagination import PageNumberPagination  from rest_framework.request import Request  from rest_framework.response import Response +from rest_framework.serializers import ModelSerializer  from rest_framework.viewsets import ModelViewSet -from rest_framework_bulk import BulkCreateModelMixin  from pydis_site.apps.api.models.bot.user import User  from pydis_site.apps.api.serializers import UserSerializer @@ -17,7 +17,7 @@ class UserListPagination(PageNumberPagination):      page_size_query_param = "page_size" -class UserViewSet(BulkCreateModelMixin, ModelViewSet): +class UserViewSet(ModelViewSet):      """      View providing CRUD operations on Discord users through the bot. @@ -142,7 +142,14 @@ class UserViewSet(BulkCreateModelMixin, ModelViewSet):      queryset = User.objects.all()      pagination_class = UserListPagination -    @action(detail=False, methods=["PATCH"]) +    def get_serializer(self, *args, **kwargs) -> ModelSerializer: +        """Set Serializer many attribute to True if request body contains a list.""" +        if isinstance(kwargs.get('data', {}), list): +            kwargs['many'] = True + +        return super().get_serializer(*args, **kwargs) + +    @action(detail=False, methods=["PATCH"], name='user-bulk-patch')      def bulk_patch(self, request: Request) -> Response:          """          Update multiple User objects in a single request. @@ -186,6 +193,13 @@ class UserViewSet(BulkCreateModelMixin, ModelViewSet):          filtered_instances = queryset.filter(id__in=object_ids) +        if filtered_instances.count() != len(object_ids): +            # If all user objects passed in request.body are not present in the database. +            resp = { +                "Error": "User object not found." +            } +            return Response(resp, status=status.HTTP_404_NOT_FOUND) +          serializer = self.get_serializer(              instance=filtered_instances,              data=request.data, | 
