diff options
author | 2023-01-27 22:19:28 +0200 | |
---|---|---|
committer | 2023-01-27 22:19:28 +0200 | |
commit | d3eec93b36bd57c521e70b4001c74cb9756caf23 (patch) | |
tree | b43536a627b2b822af4b697b3e1cf39132475e23 /pydis_site/apps/api/serializers.py | |
parent | Rename delete_messages to the more generic remove_context (diff) |
Fix filter serializers validation to account for filterlist settings
Diffstat (limited to 'pydis_site/apps/api/serializers.py')
-rw-r--r-- | pydis_site/apps/api/serializers.py | 40 |
1 files changed, 20 insertions, 20 deletions
diff --git a/pydis_site/apps/api/serializers.py b/pydis_site/apps/api/serializers.py index eabca66e..83ab4584 100644 --- a/pydis_site/apps/api/serializers.py +++ b/pydis_site/apps/api/serializers.py @@ -1,4 +1,6 @@ """Converters from Django models to data interchange formats and back.""" +from typing import Any + from django.db.models.query import QuerySet from django.db.utils import IntegrityError from rest_framework.exceptions import NotFound @@ -220,31 +222,29 @@ def _create_filter_meta_extra_kwargs() -> dict[str, dict[str, bool]]: return extra_kwargs +def get_field_value(data: dict, field_name: str) -> Any: + """Get the value directly from the key, or from the filter list if it's missing or is None.""" + if data.get(field_name): + return data[field_name] + return getattr(data["filter_list"], field_name) + + class FilterSerializer(ModelSerializer): """A class providing (de-)serialization of `Filter` instances.""" def validate(self, data: dict) -> dict: - """Perform infraction data + allow and disallowed lists validation.""" - if ( - data.get('infraction_reason') or data.get('infraction_duration') - ) and not data.get('infraction_type'): - raise ValidationError("Infraction type is required with infraction duration or reason") - + """Perform infraction data + allowed and disallowed lists validation.""" if ( - data.get('disabled_channels') is not None - and data.get('enabled_channels') is not None + (get_field_value(data, "infraction_reason") or get_field_value(data, "infraction_duration")) + and get_field_value(data, "infraction_type") == "NONE" ): - channels_collection = data['disabled_channels'] + data['enabled_channels'] - if len(channels_collection) != len(set(channels_collection)): - raise ValidationError("Enabled and Disabled channels lists contain duplicates.") + raise ValidationError("Infraction type is required with infraction duration or reason.") - if ( - data.get('disabled_categories') is not None - and data.get('enabled_categories') is not None - ): - categories_collection = data['disabled_categories'] + data['enabled_categories'] - if len(categories_collection) != len(set(categories_collection)): - raise ValidationError("Enabled and Disabled categories lists contain duplicates.") + if set(get_field_value(data, "disabled_channels")) & set(get_field_value(data, "enabled_channels")): + raise ValidationError("You can't have the same value in both enabled and disabled channels lists.") + + if set(get_field_value(data, "disabled_categories")) & set(get_field_value(data, "enabled_categories")): + raise ValidationError("You can't have the same value in both enabled and disabled categories lists.") return data @@ -318,8 +318,8 @@ class FilterListSerializer(ModelSerializer): raise ValidationError("Enabled and Disabled channels lists contain duplicates.") if ( - data.get('disabled_categories') is not None - and data.get('enabled_categories') is not None + data.get('disabled_categories') is not None + and data.get('enabled_categories') is not None ): categories_collection = data['disabled_categories'] + data['enabled_categories'] if len(categories_collection) != len(set(categories_collection)): |