diff options
-rw-r--r-- | pydis_site/apps/api/serializers.py | 17 | ||||
-rw-r--r-- | pydis_site/apps/api/tests/test_users.py | 2 | ||||
-rw-r--r-- | pydis_site/apps/api/viewsets/bot/user.py | 41 |
3 files changed, 42 insertions, 18 deletions
diff --git a/pydis_site/apps/api/serializers.py b/pydis_site/apps/api/serializers.py index 1f24d29f..a560d491 100644 --- a/pydis_site/apps/api/serializers.py +++ b/pydis_site/apps/api/serializers.py @@ -278,16 +278,20 @@ class UserListSerializer(ListSerializer): ref:https://www.django-rest-framework.org/api-guide/serializers/#customizing-multiple-update """ instance_mapping = {user.id: user for user in instance} - data_mapping = {item['id']: item for item in validated_data} updated = [] fields_to_update = set() - for user_id, data in data_mapping.items(): - for key in data: + for user_data in validated_data: + for key in user_data: fields_to_update.add(key) - user = instance_mapping.get(user_id) - user.__dict__.update(data) - updated.append(user) + + try: + user = instance_mapping[user_data["id"]] + except KeyError: + raise ValidationError({"id": f"User with id {user_data['id']} not found."}) + + user.__dict__.update(user_data) + updated.append(user) fields_to_update.remove("id") User.objects.bulk_update(updated, fields_to_update) @@ -297,6 +301,7 @@ class UserListSerializer(ListSerializer): class UserSerializer(ModelSerializer): """A class providing (de-)serialization of `User` instances.""" + # ID field must be explicitly set as the default id field is read-only. id = IntegerField(min_value=0) class Meta: diff --git a/pydis_site/apps/api/tests/test_users.py b/pydis_site/apps/api/tests/test_users.py index 1f9bd687..affc2c48 100644 --- a/pydis_site/apps/api/tests/test_users.py +++ b/pydis_site/apps/api/tests/test_users.py @@ -213,7 +213,7 @@ class MultiPatchTests(APISubdomainTestCase): } ] response = self.client.patch(url, data=data) - self.assertEqual(response.status_code, 404) + self.assertEqual(response.status_code, 400) def test_returns_400_for_bad_data(self): url = reverse("bot:user-bulk-patch", host="api") diff --git a/pydis_site/apps/api/viewsets/bot/user.py b/pydis_site/apps/api/viewsets/bot/user.py index d015fe71..0dd529be 100644 --- a/pydis_site/apps/api/viewsets/bot/user.py +++ b/pydis_site/apps/api/viewsets/bot/user.py @@ -1,3 +1,5 @@ +from collections import OrderedDict + from rest_framework import status from rest_framework.decorators import action from rest_framework.pagination import PageNumberPagination @@ -16,6 +18,30 @@ class UserListPagination(PageNumberPagination): page_size = 10000 page_size_query_param = "page_size" + def get_next_page_number(self) -> int: + """Get the next page number.""" + if not self.page.has_next(): + return None + page_number = self.page.next_page_number() + return page_number + + def get_previous_page_number(self) -> int: + """Get the previous page number.""" + if not self.page.has_previous(): + return None + + page_number = self.page.previous_page_number() + return page_number + + def get_paginated_response(self, data: list) -> Response: + """Override method to send modified response.""" + return Response(OrderedDict([ + ('count', self.page.paginator.count), + ('next_page_no', self.get_next_page_number()), + ('previous_page_no', self.get_previous_page_number()), + ('results', data) + ])) + class UserViewSet(ModelViewSet): """ @@ -193,13 +219,6 @@ class UserViewSet(ModelViewSet): filtered_instances = queryset.filter(id__in=object_ids) - if filtered_instances.count() != len(object_ids): - # If all user objects passed in request.body are not present in the database. - resp = { - "Error": "User object not found." - } - return Response(resp, status=status.HTTP_404_NOT_FOUND) - serializer = self.get_serializer( instance=filtered_instances, data=request.data, @@ -207,7 +226,7 @@ class UserViewSet(ModelViewSet): partial=True ) - if serializer.is_valid(): - serializer.save() - return Response(serializer.data, status=status.HTTP_200_OK) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + serializer.is_valid(raise_exception=True) + serializer.save() + + return Response(serializer.data, status=status.HTTP_200_OK) |