aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--pydis_site/apps/api/serializers.py17
-rw-r--r--pydis_site/apps/api/tests/test_users.py2
-rw-r--r--pydis_site/apps/api/viewsets/bot/user.py41
3 files changed, 42 insertions, 18 deletions
diff --git a/pydis_site/apps/api/serializers.py b/pydis_site/apps/api/serializers.py
index 1f24d29f..a560d491 100644
--- a/pydis_site/apps/api/serializers.py
+++ b/pydis_site/apps/api/serializers.py
@@ -278,16 +278,20 @@ class UserListSerializer(ListSerializer):
ref:https://www.django-rest-framework.org/api-guide/serializers/#customizing-multiple-update
"""
instance_mapping = {user.id: user for user in instance}
- data_mapping = {item['id']: item for item in validated_data}
updated = []
fields_to_update = set()
- for user_id, data in data_mapping.items():
- for key in data:
+ for user_data in validated_data:
+ for key in user_data:
fields_to_update.add(key)
- user = instance_mapping.get(user_id)
- user.__dict__.update(data)
- updated.append(user)
+
+ try:
+ user = instance_mapping[user_data["id"]]
+ except KeyError:
+ raise ValidationError({"id": f"User with id {user_data['id']} not found."})
+
+ user.__dict__.update(user_data)
+ updated.append(user)
fields_to_update.remove("id")
User.objects.bulk_update(updated, fields_to_update)
@@ -297,6 +301,7 @@ class UserListSerializer(ListSerializer):
class UserSerializer(ModelSerializer):
"""A class providing (de-)serialization of `User` instances."""
+ # ID field must be explicitly set as the default id field is read-only.
id = IntegerField(min_value=0)
class Meta:
diff --git a/pydis_site/apps/api/tests/test_users.py b/pydis_site/apps/api/tests/test_users.py
index 1f9bd687..affc2c48 100644
--- a/pydis_site/apps/api/tests/test_users.py
+++ b/pydis_site/apps/api/tests/test_users.py
@@ -213,7 +213,7 @@ class MultiPatchTests(APISubdomainTestCase):
}
]
response = self.client.patch(url, data=data)
- self.assertEqual(response.status_code, 404)
+ self.assertEqual(response.status_code, 400)
def test_returns_400_for_bad_data(self):
url = reverse("bot:user-bulk-patch", host="api")
diff --git a/pydis_site/apps/api/viewsets/bot/user.py b/pydis_site/apps/api/viewsets/bot/user.py
index d015fe71..0dd529be 100644
--- a/pydis_site/apps/api/viewsets/bot/user.py
+++ b/pydis_site/apps/api/viewsets/bot/user.py
@@ -1,3 +1,5 @@
+from collections import OrderedDict
+
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.pagination import PageNumberPagination
@@ -16,6 +18,30 @@ class UserListPagination(PageNumberPagination):
page_size = 10000
page_size_query_param = "page_size"
+ def get_next_page_number(self) -> int:
+ """Get the next page number."""
+ if not self.page.has_next():
+ return None
+ page_number = self.page.next_page_number()
+ return page_number
+
+ def get_previous_page_number(self) -> int:
+ """Get the previous page number."""
+ if not self.page.has_previous():
+ return None
+
+ page_number = self.page.previous_page_number()
+ return page_number
+
+ def get_paginated_response(self, data: list) -> Response:
+ """Override method to send modified response."""
+ return Response(OrderedDict([
+ ('count', self.page.paginator.count),
+ ('next_page_no', self.get_next_page_number()),
+ ('previous_page_no', self.get_previous_page_number()),
+ ('results', data)
+ ]))
+
class UserViewSet(ModelViewSet):
"""
@@ -193,13 +219,6 @@ class UserViewSet(ModelViewSet):
filtered_instances = queryset.filter(id__in=object_ids)
- if filtered_instances.count() != len(object_ids):
- # If all user objects passed in request.body are not present in the database.
- resp = {
- "Error": "User object not found."
- }
- return Response(resp, status=status.HTTP_404_NOT_FOUND)
-
serializer = self.get_serializer(
instance=filtered_instances,
data=request.data,
@@ -207,7 +226,7 @@ class UserViewSet(ModelViewSet):
partial=True
)
- if serializer.is_valid():
- serializer.save()
- return Response(serializer.data, status=status.HTTP_200_OK)
- return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+ serializer.is_valid(raise_exception=True)
+ serializer.save()
+
+ return Response(serializer.data, status=status.HTTP_200_OK)