aboutsummaryrefslogtreecommitdiffstats
path: root/pydis_site/apps
diff options
context:
space:
mode:
Diffstat (limited to 'pydis_site/apps')
-rw-r--r--pydis_site/apps/api/serializers.py11
-rw-r--r--pydis_site/apps/api/tests/test_off_topic_channel_names.py38
-rw-r--r--pydis_site/apps/api/viewsets/bot/off_topic_channel_name.py19
3 files changed, 54 insertions, 14 deletions
diff --git a/pydis_site/apps/api/serializers.py b/pydis_site/apps/api/serializers.py
index e09c383e..0d505675 100644
--- a/pydis_site/apps/api/serializers.py
+++ b/pydis_site/apps/api/serializers.py
@@ -1,4 +1,6 @@
"""Converters from Django models to data interchange formats and back."""
+from typing import List
+
from django.db.models.query import QuerySet
from django.db.utils import IntegrityError
from rest_framework.exceptions import NotFound
@@ -201,18 +203,18 @@ class ExpandedInfractionSerializer(InfractionSerializer):
class OffTopicChannelNameListSerializer(ListSerializer):
- def update(self, instance, validated_data):
- pass
+ """Custom ListSerializer to override to_representation() when list views are triggered."""
- def to_representation(self, obj: OffTopicChannelName) -> str:
+ def to_representation(self, objects: List[OffTopicChannelName]) -> List[str]:
"""
Return the representation of this `OffTopicChannelName`.
+
This only returns the name of the off topic channel name. As the model
only has a single attribute, it is unnecessary to create a nested dictionary.
Additionally, this allows off topic channel name routes to simply return an
array of names instead of objects, saving on bandwidth.
"""
- return obj.name
+ return [obj.name for obj in objects]
class OffTopicChannelNameSerializer(ModelSerializer):
@@ -220,6 +222,7 @@ class OffTopicChannelNameSerializer(ModelSerializer):
class Meta:
"""Metadata defined for the Django REST Framework."""
+
list_serializer_class = OffTopicChannelNameListSerializer
model = OffTopicChannelName
fields = ('name', 'used', 'active')
diff --git a/pydis_site/apps/api/tests/test_off_topic_channel_names.py b/pydis_site/apps/api/tests/test_off_topic_channel_names.py
index 3ab8b22d..a407654c 100644
--- a/pydis_site/apps/api/tests/test_off_topic_channel_names.py
+++ b/pydis_site/apps/api/tests/test_off_topic_channel_names.py
@@ -65,8 +65,15 @@ class EmptyDatabaseTests(APISubdomainTestCase):
class ListTests(APISubdomainTestCase):
@classmethod
def setUpTestData(cls):
- cls.test_name = OffTopicChannelName.objects.create(name='lemons-lemonade-stand', used=False)
- cls.test_name_2 = OffTopicChannelName.objects.create(name='bbq-with-bisk', used=True)
+ cls.test_name = OffTopicChannelName.objects.create(
+ name='lemons-lemonade-stand', used=False, active=True
+ )
+ cls.test_name_2 = OffTopicChannelName.objects.create(
+ name='bbq-with-bisk', used=True, active=True
+ )
+ cls.test_name_3 = OffTopicChannelName.objects.create(
+ name="frozen-with-iceman", used=True, active=False
+ )
def test_returns_name_in_list(self):
"""Return all off-topic channel names."""
@@ -78,7 +85,8 @@ class ListTests(APISubdomainTestCase):
response.json(),
[
self.test_name.name,
- self.test_name_2.name
+ self.test_name_2.name,
+ self.test_name_3.name
]
)
@@ -97,7 +105,29 @@ class ListTests(APISubdomainTestCase):
response = self.client.get(f'{url}?random_items=2')
self.assertEqual(response.status_code, 200)
- self.assertEqual(response.json(), [self.test_name.name, self.test_name_2.name])
+ self.assertEqual(response.json(), [self.test_name.name, self.test_name_3.name])
+
+ def test_returns_inactive_ot_names(self):
+ """Return inactive off topic names."""
+ url = reverse('bot:offtopicchannelname-list', host="api")
+ response = self.client.get(f"{url}?active=false")
+
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(
+ response.json(),
+ [self.test_name_3.name]
+ )
+
+ def test_returns_active_ot_names(self):
+ """Return active off topic names."""
+ url = reverse('bot:offtopicchannelname-list', host="api")
+ response = self.client.get(f"{url}?active=true")
+
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(
+ response.json(),
+ [self.test_name.name, self.test_name_2.name]
+ )
class CreationTests(APISubdomainTestCase):
diff --git a/pydis_site/apps/api/viewsets/bot/off_topic_channel_name.py b/pydis_site/apps/api/viewsets/bot/off_topic_channel_name.py
index 194e96d2..18ee84ea 100644
--- a/pydis_site/apps/api/viewsets/bot/off_topic_channel_name.py
+++ b/pydis_site/apps/api/viewsets/bot/off_topic_channel_name.py
@@ -1,8 +1,8 @@
from django.db.models import Case, Value, When
from django.db.models.query import QuerySet
-from django.http.request import HttpRequest
from django.shortcuts import get_object_or_404
from rest_framework.exceptions import ParseError
+from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.status import HTTP_201_CREATED
from rest_framework.viewsets import ModelViewSet
@@ -68,11 +68,11 @@ class OffTopicChannelNameViewSet(ModelViewSet):
name = self.kwargs[self.lookup_field]
return get_object_or_404(queryset, name=name)
- def get_queryset(self) -> QuerySet:
+ def get_queryset(self, **kwargs) -> QuerySet:
"""Returns a queryset that covers the entire OffTopicChannelName table."""
- return OffTopicChannelName.objects.all()
+ return OffTopicChannelName.objects.filter(**kwargs)
- def create(self, request: HttpRequest, *args, **kwargs) -> Response:
+ def create(self, request: Request, *args, **kwargs) -> Response:
"""
DRF method for creating a new OffTopicChannelName.
@@ -90,7 +90,7 @@ class OffTopicChannelNameViewSet(ModelViewSet):
'name': ["This query parameter is required."]
})
- def list(self, request: HttpRequest, *args, **kwargs) -> Response:
+ def list(self, request: Request, *args, **kwargs) -> Response:
"""
DRF method for listing OffTopicChannelName entries.
@@ -132,4 +132,11 @@ class OffTopicChannelNameViewSet(ModelViewSet):
serialized = self.serializer_class(queryset, many=True)
return Response(serialized.data)
- return super().list(self, request)
+ params = {}
+
+ if active_param := request.query_params.get("active"):
+ params["active"] = active_param.lower() == "true"
+
+ queryset = self.get_queryset(**params)
+ serialized = self.serializer_class(queryset, many=True)
+ return Response(serialized.data)