diff options
Diffstat (limited to '')
| -rw-r--r-- | pydis_site/apps/api/viewsets/bot/infraction.py | 42 | 
1 files changed, 42 insertions, 0 deletions
| diff --git a/pydis_site/apps/api/viewsets/bot/infraction.py b/pydis_site/apps/api/viewsets/bot/infraction.py index bd512ddd..b0c7d332 100644 --- a/pydis_site/apps/api/viewsets/bot/infraction.py +++ b/pydis_site/apps/api/viewsets/bot/infraction.py @@ -1,3 +1,6 @@ +from datetime import datetime + +from django.db.models import QuerySet  from django.http.request import HttpRequest  from django_filters.rest_framework import DjangoFilterBackend  from rest_framework.decorators import action @@ -43,8 +46,12 @@ class InfractionViewSet(      - **offset** `int`: the initial index from which to return the results (default 0)      - **search** `str`: regular expression applied to the infraction's reason      - **type** `str`: the type of the infraction +    - **types** `str`: comma separated sequence of types to filter for      - **user__id** `int`: snowflake of the user to which the infraction was applied      - **ordering** `str`: comma-separated sequence of fields to order the returned results +    - **permanent** `bool`: whether or not to retrieve permanent infractions (default True) +    - **expires_after** `isodatetime`: the earliest expires_at time to return infractions for +    - **expires_before** `isodatetime`: the latest expires_at time to return infractions for      Invalid query parameters are ignored. @@ -156,6 +163,41 @@ class InfractionViewSet(          return Response(serializer.data) +    def get_queryset(self) -> QuerySet: +        """ +        Called to fetch the initial queryset, used to implement some of the more complex filters. + +        This provides the `permanent` and the `expires_gte` and `expires_lte` options. +        """ +        filter_permanent = self.request.query_params.get('permanent') +        additional_filters = {} +        if filter_permanent is not None: +            additional_filters['expires_at__isnull'] = filter_permanent.lower() == 'true' + +        filter_expires_after = self.request.query_params.get('expires_after') +        if filter_expires_after: +            try: +                additional_filters['expires_at__gte'] = datetime.fromisoformat( +                    filter_expires_after +                ) +            except ValueError: +                raise ValidationError({'expires_after': ['failed to convert to datetime']}) + +        filter_expires_before = self.request.query_params.get('expires_before') +        if filter_expires_before: +            try: +                additional_filters['expires_at__lte'] = datetime.fromisoformat( +                    filter_expires_before +                ) +            except ValueError: +                raise ValidationError({'expires_before': ['failed to convert to datetime']}) + +        filter_types = self.request.query_params.get('types') +        if filter_types: +            additional_filters['type__in'] = [i.strip() for i in filter_types.split(",")] + +        return self.queryset.filter(**additional_filters) +      @action(url_path='expanded', detail=False)      def list_expanded(self, *args, **kwargs) -> Response:          """ | 
