aboutsummaryrefslogtreecommitdiffstats
path: root/pydis_site/apps
diff options
context:
space:
mode:
authorGravatar rohanjnr <[email protected]>2020-08-28 12:58:55 +0530
committerGravatar rohanjnr <[email protected]>2020-08-28 12:58:55 +0530
commit567f7f0c4a71ace555c9b3123ef50d6ae47756cd (patch)
tree00cc320c5e4096eddcf9ac92b0dfd35321e5f48c /pydis_site/apps
parentExcept AttributeError when self.instance is None and while fetching User obje... (diff)
Add code to replace restframework_bulk package for bulk create and simplify UserListSerializer
`to_internal_value()` function is not longer overriden in UserListSerializer, this is due to explicitly stating the `id` field in UserSerializer as mentioned in the documentation. Override `create()` method in UserListSerializer and override `get_serializer()` method in `UserViewSet` to support bulk creation.
Diffstat (limited to 'pydis_site/apps')
-rw-r--r--pydis_site/apps/api/serializers.py73
-rw-r--r--pydis_site/apps/api/viewsets/bot/user.py20
2 files changed, 35 insertions, 58 deletions
diff --git a/pydis_site/apps/api/serializers.py b/pydis_site/apps/api/serializers.py
index 0589ce77..21c488a8 100644
--- a/pydis_site/apps/api/serializers.py
+++ b/pydis_site/apps/api/serializers.py
@@ -1,15 +1,13 @@
"""Converters from Django models to data interchange formats and back."""
from django.db.models.query import QuerySet
from rest_framework.serializers import (
+ IntegerField,
ListSerializer,
ModelSerializer,
PrimaryKeyRelatedField,
ValidationError
)
-from rest_framework.settings import api_settings
-from rest_framework.utils import html
from rest_framework.validators import UniqueTogetherValidator
-from rest_framework_bulk import BulkSerializerMixin
from .models import (
BotSetting,
@@ -271,55 +269,18 @@ class TagSerializer(ModelSerializer):
class UserListSerializer(ListSerializer):
"""List serializer for User model to handle bulk updates."""
- def to_internal_value(self, data: list) -> list:
- """
- Overriding `to_internal_value` function with a few changes to support bulk updates.
-
- ref: https://github.com/miki725/django-rest-framework-bulk/issues/68
+ def create(self, validated_data: list) -> list:
+ """Override create method to optimize django queries."""
+ present_users = User.objects.all()
+ new_users = []
+ present_user_ids = [user.id for user in present_users]
- List of dicts of native values <- List of dicts of primitive datatypes.
- """
- if html.is_html_input(data):
- data = html.parse_html_list(data, default=[])
+ for user_dict in validated_data:
+ if user_dict["id"] in present_user_ids:
+ raise ValidationError({"id": "User already exists."})
+ new_users.append(User(**user_dict))
- if not isinstance(data, list):
- message = self.error_messages['not_a_list'].format(
- input_type=type(data).__name__
- )
- raise ValidationError({
- api_settings.NON_FIELD_ERRORS_KEY: [message]
- }, code='not_a_list')
-
- if not self.allow_empty and len(data) == 0:
- message = self.error_messages['empty']
- raise ValidationError({
- api_settings.NON_FIELD_ERRORS_KEY: [message]
- }, code='empty')
-
- ret = []
- errors = []
-
- for item in data:
- # inserted code
- # -----------------
- try:
- self.child.instance = self.instance.get(id=item['id'])
- except (User.DoesNotExist, AttributeError):
- self.child.instance = None
- # -----------------
- self.child.initial_data = item
- try:
- validated = self.child.run_validation(item)
- except ValidationError as exc:
- errors.append(exc.detail)
- else:
- ret.append(validated)
- errors.append({})
-
- if any(errors):
- raise ValidationError(errors)
-
- return ret
+ return User.objects.bulk_create(new_users)
def update(self, instance: QuerySet, validated_data: list) -> list:
"""
@@ -331,17 +292,19 @@ class UserListSerializer(ListSerializer):
data_mapping = {item['id']: item for item in validated_data}
updated = []
- for book_id, data in data_mapping.items():
- book = instance_mapping.get(book_id, None)
- if book is not None:
- updated.append(self.child.update(book, data))
+ for user_id, data in data_mapping.items():
+ user = instance_mapping.get(user_id, None)
+ if user:
+ updated.append(self.child.update(user, data))
return updated
-class UserSerializer(BulkSerializerMixin, ModelSerializer):
+class UserSerializer(ModelSerializer):
"""A class providing (de-)serialization of `User` instances."""
+ id = IntegerField(min_value=0)
+
class Meta:
"""Metadata defined for the Django REST Framework."""
diff --git a/pydis_site/apps/api/viewsets/bot/user.py b/pydis_site/apps/api/viewsets/bot/user.py
index d64ca113..d015fe71 100644
--- a/pydis_site/apps/api/viewsets/bot/user.py
+++ b/pydis_site/apps/api/viewsets/bot/user.py
@@ -3,8 +3,8 @@ from rest_framework.decorators import action
from rest_framework.pagination import PageNumberPagination
from rest_framework.request import Request
from rest_framework.response import Response
+from rest_framework.serializers import ModelSerializer
from rest_framework.viewsets import ModelViewSet
-from rest_framework_bulk import BulkCreateModelMixin
from pydis_site.apps.api.models.bot.user import User
from pydis_site.apps.api.serializers import UserSerializer
@@ -17,7 +17,7 @@ class UserListPagination(PageNumberPagination):
page_size_query_param = "page_size"
-class UserViewSet(BulkCreateModelMixin, ModelViewSet):
+class UserViewSet(ModelViewSet):
"""
View providing CRUD operations on Discord users through the bot.
@@ -142,7 +142,14 @@ class UserViewSet(BulkCreateModelMixin, ModelViewSet):
queryset = User.objects.all()
pagination_class = UserListPagination
- @action(detail=False, methods=["PATCH"])
+ def get_serializer(self, *args, **kwargs) -> ModelSerializer:
+ """Set Serializer many attribute to True if request body contains a list."""
+ if isinstance(kwargs.get('data', {}), list):
+ kwargs['many'] = True
+
+ return super().get_serializer(*args, **kwargs)
+
+ @action(detail=False, methods=["PATCH"], name='user-bulk-patch')
def bulk_patch(self, request: Request) -> Response:
"""
Update multiple User objects in a single request.
@@ -186,6 +193,13 @@ class UserViewSet(BulkCreateModelMixin, ModelViewSet):
filtered_instances = queryset.filter(id__in=object_ids)
+ if filtered_instances.count() != len(object_ids):
+ # If all user objects passed in request.body are not present in the database.
+ resp = {
+ "Error": "User object not found."
+ }
+ return Response(resp, status=status.HTTP_404_NOT_FOUND)
+
serializer = self.get_serializer(
instance=filtered_instances,
data=request.data,