aboutsummaryrefslogtreecommitdiffstats
path: root/pydis_site/apps
diff options
context:
space:
mode:
authorGravatar rohanjnr <[email protected]>2020-08-26 21:59:46 +0530
committerGravatar rohanjnr <[email protected]>2020-08-26 21:59:46 +0530
commit8e636a54b449f44f5bff56577a05d9a6a2dd72c0 (patch)
treeb69d8cab36ccf49281c47447839de7e2efe1245f /pydis_site/apps
parentadd pagination for GET request on /bot/users endpoint (diff)
add support for bulk updates on user model
implemented a method to handle bulk updates on user model via a new endpoint: /bot/users/bulk_patch
Diffstat (limited to 'pydis_site/apps')
-rw-r--r--pydis_site/apps/api/serializers.py77
-rw-r--r--pydis_site/apps/api/viewsets/bot/user.py60
2 files changed, 136 insertions, 1 deletions
diff --git a/pydis_site/apps/api/serializers.py b/pydis_site/apps/api/serializers.py
index 52e0d972..757faeae 100644
--- a/pydis_site/apps/api/serializers.py
+++ b/pydis_site/apps/api/serializers.py
@@ -1,5 +1,13 @@
"""Converters from Django models to data interchange formats and back."""
-from rest_framework.serializers import ModelSerializer, PrimaryKeyRelatedField, ValidationError
+from django.db.models.query import QuerySet
+from rest_framework.serializers import (
+ ListSerializer,
+ ModelSerializer,
+ PrimaryKeyRelatedField,
+ ValidationError
+)
+from rest_framework.settings import api_settings
+from rest_framework.utils import html
from rest_framework.validators import UniqueTogetherValidator
from rest_framework_bulk import BulkSerializerMixin
@@ -260,6 +268,72 @@ class TagSerializer(ModelSerializer):
fields = ('title', 'embed')
+class UserListSerializer(ListSerializer):
+ """List serializer for User model to handle bulk updates."""
+
+ def to_internal_value(self, data: list) -> list:
+ """
+ Overriding `to_internal_value` function with a few changes to support bulk updates.
+
+ List of dicts of native values <- List of dicts of primitive datatypes.
+ """
+ if html.is_html_input(data):
+ data = html.parse_html_list(data, default=[])
+
+ if not isinstance(data, list):
+ message = self.error_messages['not_a_list'].format(
+ input_type=type(data).__name__
+ )
+ raise ValidationError({
+ api_settings.NON_FIELD_ERRORS_KEY: [message]
+ }, code='not_a_list')
+
+ if not self.allow_empty and len(data) == 0:
+ message = self.error_messages['empty']
+ raise ValidationError({
+ api_settings.NON_FIELD_ERRORS_KEY: [message]
+ }, code='empty')
+
+ ret = []
+ errors = []
+
+ for item in data:
+ # inserted code
+ # bug: https://github.com/miki725/django-rest-framework-bulk/issues/68
+ # -----------------
+ try:
+ self.child.instance = self.instance.get(id=item['id'])
+ except User.DoesNotExist:
+ self.child.instance = None
+ # -----------------
+ self.child.initial_data = item
+ try:
+ validated = self.child.run_validation(item)
+ except ValidationError as exc:
+ errors.append(exc.detail)
+ else:
+ ret.append(validated)
+ errors.append({})
+
+ if any(errors):
+ raise ValidationError(errors)
+
+ return ret
+
+ def update(self, instance: QuerySet, validated_data: list) -> list:
+ """Override update method to support bulk updates."""
+ instance_mapping = {user.id: user for user in instance}
+ data_mapping = {item['id']: item for item in validated_data}
+
+ updated = []
+ for book_id, data in data_mapping.items():
+ book = instance_mapping.get(book_id, None)
+ if book is not None:
+ updated.append(self.child.update(book, data))
+
+ return updated
+
+
class UserSerializer(BulkSerializerMixin, ModelSerializer):
"""A class providing (de-)serialization of `User` instances."""
@@ -269,6 +343,7 @@ class UserSerializer(BulkSerializerMixin, ModelSerializer):
model = User
fields = ('id', 'name', 'discriminator', 'roles', 'in_guild')
depth = 1
+ list_serializer_class = UserListSerializer
class NominationSerializer(ModelSerializer):
diff --git a/pydis_site/apps/api/viewsets/bot/user.py b/pydis_site/apps/api/viewsets/bot/user.py
index b016bb66..d64ca113 100644
--- a/pydis_site/apps/api/viewsets/bot/user.py
+++ b/pydis_site/apps/api/viewsets/bot/user.py
@@ -1,4 +1,8 @@
+from rest_framework import status
+from rest_framework.decorators import action
from rest_framework.pagination import PageNumberPagination
+from rest_framework.request import Request
+from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
from rest_framework_bulk import BulkCreateModelMixin
@@ -137,3 +141,59 @@ class UserViewSet(BulkCreateModelMixin, ModelViewSet):
serializer_class = UserSerializer
queryset = User.objects.all()
pagination_class = UserListPagination
+
+ @action(detail=False, methods=["PATCH"])
+ def bulk_patch(self, request: Request) -> Response:
+ """
+ Update multiple User objects in a single request.
+
+ ## Route
+ ### PATCH /bot/users/bulk_patch
+ Update all users with the IDs.
+ `id` field is mandatory, rest are optional.
+
+ #### Request body
+ >>> [
+ ... {
+ ... 'id': int,
+ ... 'name': str,
+ ... 'discriminator': int,
+ ... 'roles': List[int],
+ ... 'in_guild': bool
+ ... },
+ ... {
+ ... 'id': int,
+ ... 'name': str,
+ ... 'discriminator': int,
+ ... 'roles': List[int],
+ ... 'in_guild': bool
+ ... },
+ ... ]
+
+ #### Status codes
+ - 200: Returned on success.
+ - 400: if the request body was invalid, see response body for details.
+ """
+ queryset = self.get_queryset()
+ try:
+ object_ids = [item["id"] for item in request.data]
+ except KeyError:
+ # user ID not provided in request body.
+ resp = {
+ "Error": "User ID not provided."
+ }
+ return Response(resp, status=status.HTTP_400_BAD_REQUEST)
+
+ filtered_instances = queryset.filter(id__in=object_ids)
+
+ serializer = self.get_serializer(
+ instance=filtered_instances,
+ data=request.data,
+ many=True,
+ partial=True
+ )
+
+ if serializer.is_valid():
+ serializer.save()
+ return Response(serializer.data, status=status.HTTP_200_OK)
+ return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)