diff options
Diffstat (limited to '')
| -rw-r--r-- | pydis_site/apps/api/tests/test_infractions.py | 174 | ||||
| -rw-r--r-- | pydis_site/apps/api/viewsets/bot/infraction.py | 73 | 
2 files changed, 241 insertions, 6 deletions
| diff --git a/pydis_site/apps/api/tests/test_infractions.py b/pydis_site/apps/api/tests/test_infractions.py index 82b497aa..967698ff 100644 --- a/pydis_site/apps/api/tests/test_infractions.py +++ b/pydis_site/apps/api/tests/test_infractions.py @@ -1,3 +1,4 @@ +import datetime  from datetime import datetime as dt, timedelta, timezone  from unittest.mock import patch  from urllib.parse import quote @@ -16,7 +17,7 @@ class UnauthenticatedTests(APISubdomainTestCase):          self.client.force_authenticate(user=None)      def test_detail_lookup_returns_401(self): -        url = reverse('bot:infraction-detail', args=(5,), host='api') +        url = reverse('bot:infraction-detail', args=(6,), host='api')          response = self.client.get(url)          self.assertEqual(response.status_code, 401) @@ -34,7 +35,7 @@ class UnauthenticatedTests(APISubdomainTestCase):          self.assertEqual(response.status_code, 401)      def test_partial_update_returns_401(self): -        url = reverse('bot:infraction-detail', args=(5,), host='api') +        url = reverse('bot:infraction-detail', args=(6,), host='api')          response = self.client.patch(url, data={'reason': 'Have a nice day.'})          self.assertEqual(response.status_code, 401) @@ -44,7 +45,7 @@ class InfractionTests(APISubdomainTestCase):      @classmethod      def setUpTestData(cls):          cls.user = User.objects.create( -            id=5, +            id=6,              name='james',              discriminator=1,          ) @@ -64,6 +65,30 @@ class InfractionTests(APISubdomainTestCase):              reason='James is an ass, and we won\'t be working with him again.',              active=False          ) +        cls.mute_permanent = Infraction.objects.create( +            user_id=cls.user.id, +            actor_id=cls.user.id, +            type='mute', +            reason='He has a filthy mouth and I am his soap.', +            active=True, +            expires_at=None +        ) +        cls.superstar_expires_soon = Infraction.objects.create( +            user_id=cls.user.id, +            actor_id=cls.user.id, +            type='superstar', +            reason='This one doesn\'t matter anymore.', +            active=True, +            expires_at=datetime.datetime.utcnow() + datetime.timedelta(hours=5) +        ) +        cls.voiceban_expires_later = Infraction.objects.create( +            user_id=cls.user.id, +            actor_id=cls.user.id, +            type='voice_ban', +            reason='Jet engine mic', +            active=True, +            expires_at=datetime.datetime.utcnow() + datetime.timedelta(days=5) +        )      def test_list_all(self):          """Tests the list-view, which should be ordered by inserted_at (newest first).""" @@ -73,9 +98,12 @@ class InfractionTests(APISubdomainTestCase):          self.assertEqual(response.status_code, 200)          infractions = response.json() -        self.assertEqual(len(infractions), 2) -        self.assertEqual(infractions[0]['id'], self.ban_inactive.id) -        self.assertEqual(infractions[1]['id'], self.ban_hidden.id) +        self.assertEqual(len(infractions), 5) +        self.assertEqual(infractions[0]['id'], self.voiceban_expires_later.id) +        self.assertEqual(infractions[1]['id'], self.superstar_expires_soon.id) +        self.assertEqual(infractions[2]['id'], self.mute_permanent.id) +        self.assertEqual(infractions[3]['id'], self.ban_inactive.id) +        self.assertEqual(infractions[4]['id'], self.ban_hidden.id)      def test_filter_search(self):          url = reverse('bot:infraction-list', host='api') @@ -98,6 +126,140 @@ class InfractionTests(APISubdomainTestCase):          self.assertEqual(len(infractions), 1)          self.assertEqual(infractions[0]['id'], self.ban_hidden.id) +    def test_filter_permanent_false(self): +        url = reverse('bot:infraction-list', host='api') +        response = self.client.get(f'{url}?type=mute&permanent=false') + +        self.assertEqual(response.status_code, 200) +        infractions = response.json() + +        self.assertEqual(len(infractions), 0) + +    def test_filter_permanent_true(self): +        url = reverse('bot:infraction-list', host='api') +        response = self.client.get(f'{url}?type=mute&permanent=true') + +        self.assertEqual(response.status_code, 200) +        infractions = response.json() + +        self.assertEqual(infractions[0]['id'], self.mute_permanent.id) + +    def test_filter_after(self): +        url = reverse('bot:infraction-list', host='api') +        target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5) +        response = self.client.get(f'{url}?type=superstar&expires_after={target_time.isoformat()}') + +        self.assertEqual(response.status_code, 200) +        infractions = response.json() +        self.assertEqual(len(infractions), 0) + +    def test_filter_before(self): +        url = reverse('bot:infraction-list', host='api') +        target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5) +        response = self.client.get(f'{url}?type=superstar&expires_before={target_time.isoformat()}') + +        self.assertEqual(response.status_code, 200) +        infractions = response.json() +        self.assertEqual(len(infractions), 1) +        self.assertEqual(infractions[0]['id'], self.superstar_expires_soon.id) + +    def test_filter_after_invalid(self): +        url = reverse('bot:infraction-list', host='api') +        response = self.client.get(f'{url}?expires_after=gibberish') + +        self.assertEqual(response.status_code, 400) +        self.assertEqual(list(response.json())[0], "expires_after") + +    def test_filter_before_invalid(self): +        url = reverse('bot:infraction-list', host='api') +        response = self.client.get(f'{url}?expires_before=000000000') + +        self.assertEqual(response.status_code, 400) +        self.assertEqual(list(response.json())[0], "expires_before") + +    def test_after_before_before(self): +        url = reverse('bot:infraction-list', host='api') +        target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=4) +        target_time_late = datetime.datetime.utcnow() + datetime.timedelta(hours=6) +        response = self.client.get( +            f'{url}?expires_before={target_time_late.isoformat()}' +            f'&expires_after={target_time.isoformat()}' +        ) + +        self.assertEqual(response.status_code, 200) +        self.assertEqual(len(response.json()), 1) +        self.assertEqual(response.json()[0]["id"], self.superstar_expires_soon.id) + +    def test_after_after_before_invalid(self): +        url = reverse('bot:infraction-list', host='api') +        target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5) +        target_time_late = datetime.datetime.utcnow() + datetime.timedelta(hours=9) +        response = self.client.get( +            f'{url}?expires_before={target_time.isoformat()}' +            f'&expires_after={target_time_late.isoformat()}' +        ) + +        self.assertEqual(response.status_code, 400) +        errors = list(response.json()) +        self.assertIn("expires_before", errors) +        self.assertIn("expires_after", errors) + +    def test_permanent_after_invalid(self): +        url = reverse('bot:infraction-list', host='api') +        target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5) +        response = self.client.get(f'{url}?permanent=true&expires_after={target_time.isoformat()}') + +        self.assertEqual(response.status_code, 400) +        errors = list(response.json()) +        self.assertEqual("permanent", errors[0]) + +    def test_permanent_before_invalid(self): +        url = reverse('bot:infraction-list', host='api') +        target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=5) +        response = self.client.get(f'{url}?permanent=true&expires_before={target_time.isoformat()}') + +        self.assertEqual(response.status_code, 400) +        errors = list(response.json()) +        self.assertEqual("permanent", errors[0]) + +    def test_nonpermanent_before(self): +        url = reverse('bot:infraction-list', host='api') +        target_time = datetime.datetime.utcnow() + datetime.timedelta(hours=6) +        response = self.client.get( +            f'{url}?permanent=false&expires_before={target_time.isoformat()}' +        ) + +        self.assertEqual(response.status_code, 200) +        self.assertEqual(len(response.json()), 1) +        self.assertEqual(response.json()[0]["id"], self.superstar_expires_soon.id) + +    def test_filter_manytypes(self): +        url = reverse('bot:infraction-list', host='api') +        response = self.client.get(f'{url}?types=mute,ban') + +        self.assertEqual(response.status_code, 200) +        infractions = response.json() +        self.assertEqual(len(infractions), 3) + +    def test_types_type_invalid(self): +        url = reverse('bot:infraction-list', host='api') +        response = self.client.get(f'{url}?types=mute,ban&type=superstar') + +        self.assertEqual(response.status_code, 400) +        errors = list(response.json()) +        self.assertEqual("types", errors[0]) + +    def test_sort_expiresby(self): +        url = reverse('bot:infraction-list', host='api') +        response = self.client.get(f'{url}?ordering=expires_at&permanent=false') +        self.assertEqual(response.status_code, 200) +        infractions = response.json() + +        self.assertEqual(len(infractions), 3) +        self.assertEqual(infractions[0]['id'], self.superstar_expires_soon.id) +        self.assertEqual(infractions[1]['id'], self.voiceban_expires_later.id) +        self.assertEqual(infractions[2]['id'], self.ban_hidden.id) +      def test_returns_empty_for_no_match(self):          url = reverse('bot:infraction-list', host='api')          response = self.client.get(f'{url}?type=ban&search=poop') diff --git a/pydis_site/apps/api/viewsets/bot/infraction.py b/pydis_site/apps/api/viewsets/bot/infraction.py index bd512ddd..f8b0cb9d 100644 --- a/pydis_site/apps/api/viewsets/bot/infraction.py +++ b/pydis_site/apps/api/viewsets/bot/infraction.py @@ -1,3 +1,6 @@ +from datetime import datetime + +from django.db.models import QuerySet  from django.http.request import HttpRequest  from django_filters.rest_framework import DjangoFilterBackend  from rest_framework.decorators import action @@ -43,10 +46,17 @@ class InfractionViewSet(      - **offset** `int`: the initial index from which to return the results (default 0)      - **search** `str`: regular expression applied to the infraction's reason      - **type** `str`: the type of the infraction +    - **types** `str`: comma separated sequence of types to filter for      - **user__id** `int`: snowflake of the user to which the infraction was applied      - **ordering** `str`: comma-separated sequence of fields to order the returned results +    - **permanent** `bool`: whether or not to retrieve permanent infractions (default True) +    - **expires_after** `isodatetime`: the earliest expires_at time to return infractions for +    - **expires_before** `isodatetime`: the latest expires_at time to return infractions for      Invalid query parameters are ignored. +    Only one of `type` and `types` may be provided. If both `expires_before` and `expires_after` +    are provided, `expires_after` must come after `expires_before`. +    If `permanent` is provided and true, `expires_before` and `expires_after` must not be provided.      #### Response format      Response is paginated but the result is returned without any pagination metadata. @@ -156,6 +166,69 @@ class InfractionViewSet(          return Response(serializer.data) +    def get_queryset(self) -> QuerySet: +        """ +        Called to fetch the initial queryset, used to implement some of the more complex filters. + +        This provides the `permanent` and the `expires_gte` and `expires_lte` options. +        """ +        filter_permanent = self.request.query_params.get('permanent') +        additional_filters = {} +        if filter_permanent is not None: +            additional_filters['expires_at__isnull'] = filter_permanent.lower() == 'true' + +        filter_expires_after = self.request.query_params.get('expires_after') +        if filter_expires_after: +            try: +                additional_filters['expires_at__gte'] = datetime.fromisoformat( +                    filter_expires_after +                ) +            except ValueError: +                raise ValidationError({'expires_after': ['failed to convert to datetime']}) + +        filter_expires_before = self.request.query_params.get('expires_before') +        if filter_expires_before: +            try: +                additional_filters['expires_at__lte'] = datetime.fromisoformat( +                    filter_expires_before +                ) +            except ValueError: +                raise ValidationError({'expires_before': ['failed to convert to datetime']}) + +        if 'expires_at__lte' in additional_filters and 'expires_at__gte' in additional_filters: +            if additional_filters['expires_at__gte'] > additional_filters['expires_at__lte']: +                raise ValidationError({ +                    'expires_before': ['cannot be after expires_after'], +                    'expires_after': ['cannot be before expires_before'], +                }) + +        if ( +            ('expires_at__lte' in additional_filters or 'expires_at__gte' in additional_filters) +            and 'expires_at__isnull' in additional_filters +            and additional_filters['expires_at__isnull'] +        ): +            raise ValidationError({ +                'permanent': [ +                    'cannot filter for permanent infractions at the' +                    ' same time as expires_at or expires_before', +                ] +            }) + +        if filter_expires_before: +            # Filter out permanent infractions specifically if we want ones that will expire +            # before a given date +            additional_filters['expires_at__isnull'] = False + +        filter_types = self.request.query_params.get('types') +        if filter_types: +            if self.request.query_params.get('type'): +                raise ValidationError({ +                    'types': ['you must provide only one of "type" or "types"'], +                }) +            additional_filters['type__in'] = [i.strip() for i in filter_types.split(",")] + +        return self.queryset.filter(**additional_filters) +      @action(url_path='expanded', detail=False)      def list_expanded(self, *args, **kwargs) -> Response:          """ | 
