aboutsummaryrefslogtreecommitdiffstats
path: root/pydis_site
diff options
context:
space:
mode:
authorGravatar ChrisJL <[email protected]>2021-06-04 17:38:52 +0100
committerGravatar GitHub <[email protected]>2021-06-04 17:38:52 +0100
commit7d8182988734d9641597f85572ad22fe6f082eed (patch)
tree1dafe9f8b85d856657a1e34db4dd40cc0aeece8f /pydis_site
parentMerge pull request #516 from soham4abc/main (diff)
parentMerge branch 'main' into bast0006-new-infraction-filters (diff)
Merge pull request #510 from bast0006/bast0006-new-infraction-filters
Add new infraction filters for the infraction rescheduler
Diffstat (limited to 'pydis_site')
-rw-r--r--pydis_site/apps/api/tests/test_infractions.py174
-rw-r--r--pydis_site/apps/api/viewsets/bot/infraction.py73
2 files changed, 241 insertions, 6 deletions
diff --git a/pydis_site/apps/api/tests/test_infractions.py b/pydis_site/apps/api/tests/test_infractions.py
index 82b497aa..967698ff 100644
--- a/pydis_site/apps/api/tests/test_infractions.py
+++ b/pydis_site/apps/api/tests/test_infractions.py
@@ -1,3 +1,4 @@
+import datetime
from datetime import datetime as dt, timedelta, timezone
from unittest.mock import patch
from urllib.parse import quote
@@ -16,7 +17,7 @@ class UnauthenticatedTests(APISubdomainTestCase):
self.client.force_authenticate(user=None)
def test_detail_lookup_returns_401(self):
- url = reverse('bot:infraction-detail', args=(5,), host='api')
+ url = reverse('bot:infraction-detail', args=(6,), host='api')
response = self.client.get(url)
self.assertEqual(response.status_code, 401)
@@ -34,7 +35,7 @@ class UnauthenticatedTests(APISubdomainTestCase):
self.assertEqual(response.status_code, 401)
def test_partial_update_returns_401(self):
- url = reverse('bot:infraction-detail', args=(5,), host='api')
+ url = reverse('bot:infraction-detail', args=(6,), host='api')
response = self.client.patch(url, data={'reason': 'Have a nice day.'})
self.assertEqual(response.status_code, 401)
@@ -44,7 +45,7 @@ class InfractionTests(APISubdomainTestCase):
@classmethod
def setUpTestData(cls):
cls.user = User.objects.create(
- id=5,
+ id=6,
name='james',
discriminator=1,
)
@@ -64,6 +65,30 @@ class InfractionTests(APISubdomainTestCase):
reason='James is an ass, and we won\'t be working with him again.',
active=False
)
+ cls.mute_permanent = Infraction.objects.create(
+ user_id=cls.user.id,
+ actor_id=cls.user.id,
+ type='mute',
+ reason='He has a filthy mouth and I am his soap.',
+ active=True,
+ expires_at=None
+ )
+ cls.superstar_expires_soon = Infraction.objects.create(
+ user_id=cls.user.id,
+ actor_id=cls.user.id,
+ type='superstar',
+ reason='This one doesn\'t matter anymore.',
+ active=True,
+ expires_at=datetime.datetime.utcnow() + datetime.timedelta(hours=5)
+ )
+ cls.voiceban_expires_later = Infraction.objects.create(
+ user_id=cls.user.id,
+ actor_id=cls.user.id,
+ type='voice_ban',
+ reason='Jet engine mic',
+ active=True,
+ expires_at=datetime.datetime.utcnow() + datetime.timedelta(days=5)
+ )
def test_list_all(self):
"""Tests the list-view, which should be ordered by inserted_at (newest first)."""
@@ -73,9 +98,12 @@ class InfractionTests(APISubdomainTestCase):
self.assertEqual(response.status_code, 200)
infractions = response.json()
- self.assertEqual(len(infractions), 2)
- self.assertEqual(infractions[0]['id'], self.ban_inactive.id)
- self.assertEqual(infractions[1]['id'], self.ban_hidden.id)
+ self.assertEqual(len(infractions), 5)
+ self.assertEqual(infractions[0]['id'], self.voiceban_expires_later.id)
+ self.assertEqual(infractions[1]['id'], self.superstar_expires_soon.id)
+ self.assertEqual(infractions[2]['id'], self.mute_permanent.id)
+ self.assertEqual(infractions[3]['id'], self.ban_inactive.id)
+ self.assertEqual(infractions[4]['id'], self.ban_hidden.id)
def test_filter_search(self):
url = reverse('bot:infraction-list', host='api')
@@ -98,6 +126,140 @@ class InfractionTests(APISubdomainTestCase):
self.assertEqual(len(infractions), 1)
self.assertEqual(infractions[0]['id'], self.ban_hidden.id)
+ def test_filter_permanent_false(self):
+ url = reverse('bot:infraction-list', host='api')
+ response = self.client.get(f'{url}?type=mute&permanent=false')
+
+ self.assertEqual(response.status_code, 200)
+ infractions = response.json()
+
+ self.assertEqual(len(infractions), 0)
+
+ def test_filter_permanent_true(self):
+ url = reverse('bot:infraction-list', host='api')
+ response = self.client.get(f'{url}?type=mute&permanent=true')
+
+ self.assertEqual(response.status_code, 200)
+ infractions = response.json()
+
+ self.assertEqual(infractions[0]['id'], self.mute_permanent.id)
+
+ def test_filter_after(self):
+ url = reverse('bot:infraction-list', host='api')
+ target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5)
+ response = self.client.get(f'{url}?type=superstar&expires_after={target_time.isoformat()}')
+
+ self.assertEqual(response.status_code, 200)
+ infractions = response.json()
+ self.assertEqual(len(infractions), 0)
+
+ def test_filter_before(self):
+ url = reverse('bot:infraction-list', host='api')
+ target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5)
+ response = self.client.get(f'{url}?type=superstar&expires_before={target_time.isoformat()}')
+
+ self.assertEqual(response.status_code, 200)
+ infractions = response.json()
+ self.assertEqual(len(infractions), 1)
+ self.assertEqual(infractions[0]['id'], self.superstar_expires_soon.id)
+
+ def test_filter_after_invalid(self):
+ url = reverse('bot:infraction-list', host='api')
+ response = self.client.get(f'{url}?expires_after=gibberish')
+
+ self.assertEqual(response.status_code, 400)
+ self.assertEqual(list(response.json())[0], "expires_after")
+
+ def test_filter_before_invalid(self):
+ url = reverse('bot:infraction-list', host='api')
+ response = self.client.get(f'{url}?expires_before=000000000')
+
+ self.assertEqual(response.status_code, 400)
+ self.assertEqual(list(response.json())[0], "expires_before")
+
+ def test_after_before_before(self):
+ url = reverse('bot:infraction-list', host='api')
+ target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=4)
+ target_time_late = datetime.datetime.utcnow() + datetime.timedelta(hours=6)
+ response = self.client.get(
+ f'{url}?expires_before={target_time_late.isoformat()}'
+ f'&expires_after={target_time.isoformat()}'
+ )
+
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(len(response.json()), 1)
+ self.assertEqual(response.json()[0]["id"], self.superstar_expires_soon.id)
+
+ def test_after_after_before_invalid(self):
+ url = reverse('bot:infraction-list', host='api')
+ target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5)
+ target_time_late = datetime.datetime.utcnow() + datetime.timedelta(hours=9)
+ response = self.client.get(
+ f'{url}?expires_before={target_time.isoformat()}'
+ f'&expires_after={target_time_late.isoformat()}'
+ )
+
+ self.assertEqual(response.status_code, 400)
+ errors = list(response.json())
+ self.assertIn("expires_before", errors)
+ self.assertIn("expires_after", errors)
+
+ def test_permanent_after_invalid(self):
+ url = reverse('bot:infraction-list', host='api')
+ target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5)
+ response = self.client.get(f'{url}?permanent=true&expires_after={target_time.isoformat()}')
+
+ self.assertEqual(response.status_code, 400)
+ errors = list(response.json())
+ self.assertEqual("permanent", errors[0])
+
+ def test_permanent_before_invalid(self):
+ url = reverse('bot:infraction-list', host='api')
+ target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5)
+ response = self.client.get(f'{url}?permanent=true&expires_before={target_time.isoformat()}')
+
+ self.assertEqual(response.status_code, 400)
+ errors = list(response.json())
+ self.assertEqual("permanent", errors[0])
+
+ def test_nonpermanent_before(self):
+ url = reverse('bot:infraction-list', host='api')
+ target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=6)
+ response = self.client.get(
+ f'{url}?permanent=false&expires_before={target_time.isoformat()}'
+ )
+
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(len(response.json()), 1)
+ self.assertEqual(response.json()[0]["id"], self.superstar_expires_soon.id)
+
+ def test_filter_manytypes(self):
+ url = reverse('bot:infraction-list', host='api')
+ response = self.client.get(f'{url}?types=mute,ban')
+
+ self.assertEqual(response.status_code, 200)
+ infractions = response.json()
+ self.assertEqual(len(infractions), 3)
+
+ def test_types_type_invalid(self):
+ url = reverse('bot:infraction-list', host='api')
+ response = self.client.get(f'{url}?types=mute,ban&type=superstar')
+
+ self.assertEqual(response.status_code, 400)
+ errors = list(response.json())
+ self.assertEqual("types", errors[0])
+
+ def test_sort_expiresby(self):
+ url = reverse('bot:infraction-list', host='api')
+ response = self.client.get(f'{url}?ordering=expires_at&permanent=false')
+ self.assertEqual(response.status_code, 200)
+ infractions = response.json()
+
+ self.assertEqual(len(infractions), 3)
+ self.assertEqual(infractions[0]['id'], self.superstar_expires_soon.id)
+ self.assertEqual(infractions[1]['id'], self.voiceban_expires_later.id)
+ self.assertEqual(infractions[2]['id'], self.ban_hidden.id)
+
def test_returns_empty_for_no_match(self):
url = reverse('bot:infraction-list', host='api')
response = self.client.get(f'{url}?type=ban&search=poop')
diff --git a/pydis_site/apps/api/viewsets/bot/infraction.py b/pydis_site/apps/api/viewsets/bot/infraction.py
index bd512ddd..f8b0cb9d 100644
--- a/pydis_site/apps/api/viewsets/bot/infraction.py
+++ b/pydis_site/apps/api/viewsets/bot/infraction.py
@@ -1,3 +1,6 @@
+from datetime import datetime
+
+from django.db.models import QuerySet
from django.http.request import HttpRequest
from django_filters.rest_framework import DjangoFilterBackend
from rest_framework.decorators import action
@@ -43,10 +46,17 @@ class InfractionViewSet(
- **offset** `int`: the initial index from which to return the results (default 0)
- **search** `str`: regular expression applied to the infraction's reason
- **type** `str`: the type of the infraction
+ - **types** `str`: comma separated sequence of types to filter for
- **user__id** `int`: snowflake of the user to which the infraction was applied
- **ordering** `str`: comma-separated sequence of fields to order the returned results
+ - **permanent** `bool`: whether or not to retrieve permanent infractions (default True)
+ - **expires_after** `isodatetime`: the earliest expires_at time to return infractions for
+ - **expires_before** `isodatetime`: the latest expires_at time to return infractions for
Invalid query parameters are ignored.
+ Only one of `type` and `types` may be provided. If both `expires_before` and `expires_after`
+ are provided, `expires_after` must come after `expires_before`.
+ If `permanent` is provided and true, `expires_before` and `expires_after` must not be provided.
#### Response format
Response is paginated but the result is returned without any pagination metadata.
@@ -156,6 +166,69 @@ class InfractionViewSet(
return Response(serializer.data)
+ def get_queryset(self) -> QuerySet:
+ """
+ Called to fetch the initial queryset, used to implement some of the more complex filters.
+
+ This provides the `permanent` and the `expires_gte` and `expires_lte` options.
+ """
+ filter_permanent = self.request.query_params.get('permanent')
+ additional_filters = {}
+ if filter_permanent is not None:
+ additional_filters['expires_at__isnull'] = filter_permanent.lower() == 'true'
+
+ filter_expires_after = self.request.query_params.get('expires_after')
+ if filter_expires_after:
+ try:
+ additional_filters['expires_at__gte'] = datetime.fromisoformat(
+ filter_expires_after
+ )
+ except ValueError:
+ raise ValidationError({'expires_after': ['failed to convert to datetime']})
+
+ filter_expires_before = self.request.query_params.get('expires_before')
+ if filter_expires_before:
+ try:
+ additional_filters['expires_at__lte'] = datetime.fromisoformat(
+ filter_expires_before
+ )
+ except ValueError:
+ raise ValidationError({'expires_before': ['failed to convert to datetime']})
+
+ if 'expires_at__lte' in additional_filters and 'expires_at__gte' in additional_filters:
+ if additional_filters['expires_at__gte'] > additional_filters['expires_at__lte']:
+ raise ValidationError({
+ 'expires_before': ['cannot be after expires_after'],
+ 'expires_after': ['cannot be before expires_before'],
+ })
+
+ if (
+ ('expires_at__lte' in additional_filters or 'expires_at__gte' in additional_filters)
+ and 'expires_at__isnull' in additional_filters
+ and additional_filters['expires_at__isnull']
+ ):
+ raise ValidationError({
+ 'permanent': [
+ 'cannot filter for permanent infractions at the'
+ ' same time as expires_at or expires_before',
+ ]
+ })
+
+ if filter_expires_before:
+ # Filter out permanent infractions specifically if we want ones that will expire
+ # before a given date
+ additional_filters['expires_at__isnull'] = False
+
+ filter_types = self.request.query_params.get('types')
+ if filter_types:
+ if self.request.query_params.get('type'):
+ raise ValidationError({
+ 'types': ['you must provide only one of "type" or "types"'],
+ })
+ additional_filters['type__in'] = [i.strip() for i in filter_types.split(",")]
+
+ return self.queryset.filter(**additional_filters)
+
@action(url_path='expanded', detail=False)
def list_expanded(self, *args, **kwargs) -> Response:
"""