aboutsummaryrefslogtreecommitdiffstats
path: root/pydis_site
diff options
context:
space:
mode:
Diffstat (limited to 'pydis_site')
-rw-r--r--pydis_site/apps/api/serializers.py54
-rw-r--r--pydis_site/apps/api/tests/test_users.py116
-rw-r--r--pydis_site/apps/api/viewsets/bot/user.py108
3 files changed, 267 insertions, 11 deletions
diff --git a/pydis_site/apps/api/serializers.py b/pydis_site/apps/api/serializers.py
index 90bd6f91..1f24d29f 100644
--- a/pydis_site/apps/api/serializers.py
+++ b/pydis_site/apps/api/serializers.py
@@ -1,7 +1,13 @@
"""Converters from Django models to data interchange formats and back."""
-from rest_framework.serializers import ModelSerializer, PrimaryKeyRelatedField, ValidationError
+from django.db.models.query import QuerySet
+from rest_framework.serializers import (
+ IntegerField,
+ ListSerializer,
+ ModelSerializer,
+ PrimaryKeyRelatedField,
+ ValidationError
+)
from rest_framework.validators import UniqueTogetherValidator
-from rest_framework_bulk import BulkSerializerMixin
from .models import (
BotSetting,
@@ -249,15 +255,57 @@ class RoleSerializer(ModelSerializer):
fields = ('id', 'name', 'colour', 'permissions', 'position')
-class UserSerializer(BulkSerializerMixin, ModelSerializer):
+class UserListSerializer(ListSerializer):
+ """List serializer for User model to handle bulk updates."""
+
+ def create(self, validated_data: list) -> list:
+ """Override create method to optimize django queries."""
+ present_users = User.objects.all()
+ new_users = []
+ present_user_ids = [user.id for user in present_users]
+
+ for user_dict in validated_data:
+ if user_dict["id"] in present_user_ids:
+ raise ValidationError({"id": "User already exists."})
+ new_users.append(User(**user_dict))
+
+ return User.objects.bulk_create(new_users)
+
+ def update(self, instance: QuerySet, validated_data: list) -> list:
+ """
+ Override update method to support bulk updates.
+
+ ref:https://www.django-rest-framework.org/api-guide/serializers/#customizing-multiple-update
+ """
+ instance_mapping = {user.id: user for user in instance}
+ data_mapping = {item['id']: item for item in validated_data}
+
+ updated = []
+ fields_to_update = set()
+ for user_id, data in data_mapping.items():
+ for key in data:
+ fields_to_update.add(key)
+ user = instance_mapping.get(user_id)
+ user.__dict__.update(data)
+ updated.append(user)
+
+ fields_to_update.remove("id")
+ User.objects.bulk_update(updated, fields_to_update)
+ return updated
+
+
+class UserSerializer(ModelSerializer):
"""A class providing (de-)serialization of `User` instances."""
+ id = IntegerField(min_value=0)
+
class Meta:
"""Metadata defined for the Django REST Framework."""
model = User
fields = ('id', 'name', 'discriminator', 'roles', 'in_guild')
depth = 1
+ list_serializer_class = UserListSerializer
class NominationSerializer(ModelSerializer):
diff --git a/pydis_site/apps/api/tests/test_users.py b/pydis_site/apps/api/tests/test_users.py
index a02fce8a..1f9bd687 100644
--- a/pydis_site/apps/api/tests/test_users.py
+++ b/pydis_site/apps/api/tests/test_users.py
@@ -45,6 +45,13 @@ class CreationTests(APISubdomainTestCase):
position=1
)
+ cls.user = User.objects.create(
+ id=11,
+ name="Name doesn't matter.",
+ discriminator=1122,
+ in_guild=True
+ )
+
def test_accepts_valid_data(self):
url = reverse('bot:user-list', host='api')
data = {
@@ -115,6 +122,115 @@ class CreationTests(APISubdomainTestCase):
response = self.client.post(url, data=data)
self.assertEqual(response.status_code, 400)
+ def test_returns_400_for_user_recreation(self):
+ """Return 400 if User is already present in database."""
+ url = reverse('bot:user-list', host='api')
+ data = [{
+ 'id': 11,
+ 'name': 'You saw nothing.',
+ 'discriminator': 112,
+ 'in_guild': True
+ }]
+ response = self.client.post(url, data=data)
+ self.assertEqual(response.status_code, 400)
+
+
+class MultiPatchTests(APISubdomainTestCase):
+ @classmethod
+ def setUpTestData(cls):
+ cls.role_developer = Role.objects.create(
+ id=159,
+ name="Developer",
+ colour=2,
+ permissions=0b01010010101,
+ position=10,
+ )
+ cls.user_1 = User.objects.create(
+ id=1,
+ name="Patch test user 1.",
+ discriminator=1111,
+ in_guild=True
+ )
+ cls.user_2 = User.objects.create(
+ id=2,
+ name="Patch test user 2.",
+ discriminator=2222,
+ in_guild=True
+ )
+
+ def test_multiple_users_patch(self):
+ url = reverse("bot:user-bulk-patch", host="api")
+ data = [
+ {
+ "id": 1,
+ "name": "User 1 patched!",
+ "discriminator": 1010,
+ "roles": [self.role_developer.id],
+ "in_guild": False
+ },
+ {
+ "id": 2,
+ "name": "User 2 patched!"
+ }
+ ]
+
+ response = self.client.patch(url, data=data)
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.json()[0], data[0])
+
+ user_2 = User.objects.get(id=2)
+ self.assertEqual(user_2.name, data[1]["name"])
+
+ def test_returns_400_for_missing_user_id(self):
+ url = reverse("bot:user-bulk-patch", host="api")
+ data = [
+ {
+ "name": "I am ghost user!",
+ "discriminator": 1010,
+ "roles": [self.role_developer.id],
+ "in_guild": False
+ },
+ {
+ "name": "patch me? whats my id?"
+ }
+ ]
+ response = self.client.patch(url, data=data)
+ self.assertEqual(response.status_code, 400)
+
+ def test_returns_404_for_not_found_user(self):
+ url = reverse("bot:user-bulk-patch", host="api")
+ data = [
+ {
+ "id": 1,
+ "name": "User 1 patched again!!!",
+ "discriminator": 1010,
+ "roles": [self.role_developer.id],
+ "in_guild": False
+ },
+ {
+ "id": 22503405,
+ "name": "User unknown not patched!"
+ }
+ ]
+ response = self.client.patch(url, data=data)
+ self.assertEqual(response.status_code, 404)
+
+ def test_returns_400_for_bad_data(self):
+ url = reverse("bot:user-bulk-patch", host="api")
+ data = [
+ {
+ "id": 1,
+ "in_guild": "Catch me!"
+ },
+ {
+ "id": 2,
+ "discriminator": "find me!"
+ }
+ ]
+
+ response = self.client.patch(url, data=data)
+ self.assertEqual(response.status_code, 400)
+
class UserModelTests(APISubdomainTestCase):
@classmethod
diff --git a/pydis_site/apps/api/viewsets/bot/user.py b/pydis_site/apps/api/viewsets/bot/user.py
index 9571b3d7..d015fe71 100644
--- a/pydis_site/apps/api/viewsets/bot/user.py
+++ b/pydis_site/apps/api/viewsets/bot/user.py
@@ -1,21 +1,37 @@
+from rest_framework import status
+from rest_framework.decorators import action
+from rest_framework.pagination import PageNumberPagination
+from rest_framework.request import Request
+from rest_framework.response import Response
+from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import ModelViewSet
-from rest_framework_bulk import BulkCreateModelMixin
from pydis_site.apps.api.models.bot.user import User
from pydis_site.apps.api.serializers import UserSerializer
-class UserViewSet(BulkCreateModelMixin, ModelViewSet):
+class UserListPagination(PageNumberPagination):
+ """Custom pagination class for the User Model."""
+
+ page_size = 10000
+ page_size_query_param = "page_size"
+
+
+class UserViewSet(ModelViewSet):
"""
View providing CRUD operations on Discord users through the bot.
## Routes
### GET /bot/users
- Returns all users currently known.
+ Returns all users currently known with pagination.
#### Response format
- >>> [
- ... {
+ >>> {
+ ... 'count': 95000,
+ ... 'next': "http://api.pythondiscord.com/bot/users?page=2",
+ ... 'previous': None,
+ ... 'results': [
+ ... {
... 'id': 409107086526644234,
... 'name': "Python",
... 'discriminator': 4329,
@@ -26,8 +42,13 @@ class UserViewSet(BulkCreateModelMixin, ModelViewSet):
... 458226699344019457
... ],
... 'in_guild': True
- ... }
- ... ]
+ ... },
+ ... ]
+ ... }
+
+ #### Query Parameters
+ - page_size: Number of Users in one page.
+ - page: Page number
#### Status codes
- 200: returned on success
@@ -118,4 +139,75 @@ class UserViewSet(BulkCreateModelMixin, ModelViewSet):
"""
serializer_class = UserSerializer
- queryset = User.objects
+ queryset = User.objects.all()
+ pagination_class = UserListPagination
+
+ def get_serializer(self, *args, **kwargs) -> ModelSerializer:
+ """Set Serializer many attribute to True if request body contains a list."""
+ if isinstance(kwargs.get('data', {}), list):
+ kwargs['many'] = True
+
+ return super().get_serializer(*args, **kwargs)
+
+ @action(detail=False, methods=["PATCH"], name='user-bulk-patch')
+ def bulk_patch(self, request: Request) -> Response:
+ """
+ Update multiple User objects in a single request.
+
+ ## Route
+ ### PATCH /bot/users/bulk_patch
+ Update all users with the IDs.
+ `id` field is mandatory, rest are optional.
+
+ #### Request body
+ >>> [
+ ... {
+ ... 'id': int,
+ ... 'name': str,
+ ... 'discriminator': int,
+ ... 'roles': List[int],
+ ... 'in_guild': bool
+ ... },
+ ... {
+ ... 'id': int,
+ ... 'name': str,
+ ... 'discriminator': int,
+ ... 'roles': List[int],
+ ... 'in_guild': bool
+ ... },
+ ... ]
+
+ #### Status codes
+ - 200: Returned on success.
+ - 400: if the request body was invalid, see response body for details.
+ """
+ queryset = self.get_queryset()
+ try:
+ object_ids = [item["id"] for item in request.data]
+ except KeyError:
+ # user ID not provided in request body.
+ resp = {
+ "Error": "User ID not provided."
+ }
+ return Response(resp, status=status.HTTP_400_BAD_REQUEST)
+
+ filtered_instances = queryset.filter(id__in=object_ids)
+
+ if filtered_instances.count() != len(object_ids):
+ # If all user objects passed in request.body are not present in the database.
+ resp = {
+ "Error": "User object not found."
+ }
+ return Response(resp, status=status.HTTP_404_NOT_FOUND)
+
+ serializer = self.get_serializer(
+ instance=filtered_instances,
+ data=request.data,
+ many=True,
+ partial=True
+ )
+
+ if serializer.is_valid():
+ serializer.save()
+ return Response(serializer.data, status=status.HTTP_200_OK)
+ return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)