diff options
Diffstat (limited to 'pydis_site/apps/api/viewsets')
| -rw-r--r-- | pydis_site/apps/api/viewsets/bot/off_topic_channel_name.py | 28 | 
1 files changed, 16 insertions, 12 deletions
| diff --git a/pydis_site/apps/api/viewsets/bot/off_topic_channel_name.py b/pydis_site/apps/api/viewsets/bot/off_topic_channel_name.py index 922e6555..78f8c340 100644 --- a/pydis_site/apps/api/viewsets/bot/off_topic_channel_name.py +++ b/pydis_site/apps/api/viewsets/bot/off_topic_channel_name.py @@ -1,18 +1,17 @@  from django.db.models import Case, Value, When  from django.db.models.query import QuerySet -from django.http.request import HttpRequest  from django.shortcuts import get_object_or_404  from rest_framework.exceptions import ParseError -from rest_framework.mixins import DestroyModelMixin +from rest_framework.request import Request  from rest_framework.response import Response  from rest_framework.status import HTTP_201_CREATED -from rest_framework.viewsets import ViewSet +from rest_framework.viewsets import ModelViewSet  from pydis_site.apps.api.models.bot.off_topic_channel_name import OffTopicChannelName  from pydis_site.apps.api.serializers import OffTopicChannelNameSerializer -class OffTopicChannelNameViewSet(DestroyModelMixin, ViewSet): +class OffTopicChannelNameViewSet(ModelViewSet):      """      View of off-topic channel names used by the bot to rotate our off-topic names on a daily basis. @@ -58,6 +57,7 @@ class OffTopicChannelNameViewSet(DestroyModelMixin, ViewSet):      lookup_field = 'name'      serializer_class = OffTopicChannelNameSerializer +    queryset = OffTopicChannelName.objects.all()      def get_object(self) -> OffTopicChannelName:          """ @@ -65,15 +65,14 @@ class OffTopicChannelNameViewSet(DestroyModelMixin, ViewSet):          If it doesn't, a HTTP 404 is returned by way of throwing an exception.          """ -        queryset = self.get_queryset()          name = self.kwargs[self.lookup_field] -        return get_object_or_404(queryset, name=name) +        return get_object_or_404(self.queryset, name=name)      def get_queryset(self) -> QuerySet:          """Returns a queryset that covers the entire OffTopicChannelName table."""          return OffTopicChannelName.objects.all() -    def create(self, request: HttpRequest) -> Response: +    def create(self, request: Request, *args, **kwargs) -> Response:          """          DRF method for creating a new OffTopicChannelName. @@ -91,7 +90,7 @@ class OffTopicChannelNameViewSet(DestroyModelMixin, ViewSet):                  'name': ["This query parameter is required."]              }) -    def list(self, request: HttpRequest) -> Response: +    def list(self, request: Request, *args, **kwargs) -> Response:          """          DRF method for listing OffTopicChannelName entries. @@ -109,13 +108,13 @@ class OffTopicChannelNameViewSet(DestroyModelMixin, ViewSet):                      'random_items': ["Must be a positive integer."]                  }) -            queryset = self.get_queryset().order_by('used', '?')[:random_count] +            queryset = self.queryset.order_by('used', '?')[:random_count]              # When any name is used in our listing then this means we reached end of round              # and we need to reset all other names `used` to False              if any(offtopic_name.used for offtopic_name in queryset):                  # These names that we just got have to be excluded from updating used to False -                self.get_queryset().update( +                self.queryset.update(                      used=Case(                          When(                              name__in=(offtopic_name.name for offtopic_name in queryset), @@ -126,13 +125,18 @@ class OffTopicChannelNameViewSet(DestroyModelMixin, ViewSet):                  )              else:                  # Otherwise mark selected names `used` to True -                self.get_queryset().filter( +                self.queryset.filter(                      name__in=(offtopic_name.name for offtopic_name in queryset)                  ).update(used=True)              serialized = self.serializer_class(queryset, many=True)              return Response(serialized.data) -        queryset = self.get_queryset() +        params = {} + +        if active_param := request.query_params.get("active"): +            params["active"] = active_param.lower() == "true" + +        queryset = self.queryset.filter(**params)          serialized = self.serializer_class(queryset, many=True)          return Response(serialized.data) | 
