diff options
author | 2022-02-20 17:43:54 +0100 | |
---|---|---|
committer | 2022-02-21 22:24:00 +0100 | |
commit | 26e4f518c874cafdee594c08c01d610e88528dc7 (patch) | |
tree | 637a3d87c4e73c364b101654dca505449990b11d | |
parent | Merge pull request #624 from python-discord/content/update-help-channel-timing (diff) |
Prevent race condition with duplicate infractions
DRF's `UniqueTogetherValidator` validates uniqueness by querying the
database before running the actual insert. This is not, has not, and
will never be valid, unless you happen to run a single worker, on a
single thread, and your single worker running on a single thread is the
only client for the database, in which case it may be valid. For any
other cases, it's invalid, and it has never been valid. PostgreSQL spits
out an `IntegrityError` for us if we have a duplicate entry, and
PostgreSQL is the only valid and correct thing to trust here.
The `UniqueTogetherValidator` is removed, and an existing test case
calling into this validator to check for uniqueness is removed.
Furthermore, to work around a Django quirk, `transaction.atomic()` is
added to prevent one `subTest` from messing with another.
Closes #665.
-rw-r--r-- | pydis_site/apps/api/serializers.py | 7 | ||||
-rw-r--r-- | pydis_site/apps/api/tests/test_infractions.py | 77 | ||||
-rw-r--r-- | pydis_site/apps/api/viewsets/bot/infraction.py | 18 |
3 files changed, 50 insertions, 52 deletions
diff --git a/pydis_site/apps/api/serializers.py b/pydis_site/apps/api/serializers.py index 4a702d61..745aff42 100644 --- a/pydis_site/apps/api/serializers.py +++ b/pydis_site/apps/api/serializers.py @@ -156,13 +156,6 @@ class InfractionSerializer(ModelSerializer): 'hidden', 'dm_sent' ) - validators = [ - UniqueTogetherValidator( - queryset=Infraction.objects.filter(active=True), - fields=['user', 'type', 'active'], - message='This user already has an active infraction of this type.', - ) - ] def validate(self, attrs: dict) -> dict: """Validate data constraints for the given data and abort if it is invalid.""" diff --git a/pydis_site/apps/api/tests/test_infractions.py b/pydis_site/apps/api/tests/test_infractions.py index b3dd16ee..aa0604f6 100644 --- a/pydis_site/apps/api/tests/test_infractions.py +++ b/pydis_site/apps/api/tests/test_infractions.py @@ -3,6 +3,7 @@ from datetime import datetime as dt, timedelta, timezone from unittest.mock import patch from urllib.parse import quote +from django.db import transaction from django.db.utils import IntegrityError from django.urls import reverse @@ -492,6 +493,7 @@ class CreationTests(AuthenticatedAPITestCase): ) for infraction_type, hidden in restricted_types: + # https://stackoverflow.com/a/23326971 with self.subTest(infraction_type=infraction_type): invalid_infraction = { 'user': self.user.id, @@ -516,37 +518,38 @@ class CreationTests(AuthenticatedAPITestCase): for infraction_type in active_infraction_types: with self.subTest(infraction_type=infraction_type): - first_active_infraction = { - 'user': self.user.id, - 'actor': self.user.id, - 'type': infraction_type, - 'reason': 'Take me on!', - 'active': True, - 'expires_at': '2019-10-04T12:52:00+00:00' - } - - # Post the first active infraction of a type and confirm it's accepted. - first_response = self.client.post(url, data=first_active_infraction) - self.assertEqual(first_response.status_code, 201) + with transaction.atomic(): + first_active_infraction = { + 'user': self.user.id, + 'actor': self.user.id, + 'type': infraction_type, + 'reason': 'Take me on!', + 'active': True, + 'expires_at': '2019-10-04T12:52:00+00:00' + } - second_active_infraction = { - 'user': self.user.id, - 'actor': self.user.id, - 'type': infraction_type, - 'reason': 'Take on me!', - 'active': True, - 'expires_at': '2019-10-04T12:52:00+00:00' - } - second_response = self.client.post(url, data=second_active_infraction) - self.assertEqual(second_response.status_code, 400) - self.assertEqual( - second_response.json(), - { - 'non_field_errors': [ - 'This user already has an active infraction of this type.' - ] + # Post the first active infraction of a type and confirm it's accepted. + first_response = self.client.post(url, data=first_active_infraction) + self.assertEqual(first_response.status_code, 201) + + second_active_infraction = { + 'user': self.user.id, + 'actor': self.user.id, + 'type': infraction_type, + 'reason': 'Take on me!', + 'active': True, + 'expires_at': '2019-10-04T12:52:00+00:00' } - ) + second_response = self.client.post(url, data=second_active_infraction) + self.assertEqual(second_response.status_code, 400) + self.assertEqual( + second_response.json(), + { + 'non_field_errors': [ + 'This user already has an active infraction of this type.' + ] + } + ) def test_returns_201_for_second_active_infraction_of_different_type(self): """Test if the API accepts a second active infraction of a different type than the first.""" @@ -811,22 +814,6 @@ class SerializerTests(AuthenticatedAPITestCase): self.assertTrue(serializer.is_valid(), msg=serializer.errors) - def test_validation_error_if_active_duplicate(self): - self.create_infraction('ban', active=True) - instance = self.create_infraction('ban', active=False) - - data = {'active': True} - serializer = InfractionSerializer(instance, data=data, partial=True) - - if not serializer.is_valid(): - self.assertIn('non_field_errors', serializer.errors) - - code = serializer.errors['non_field_errors'][0].code - msg = f'Expected failure on unique validator but got {serializer.errors}' - self.assertEqual(code, 'unique', msg=msg) - else: # pragma: no cover - self.fail('Validation unexpectedly succeeded.') - def test_is_valid_for_new_active_infraction(self): self.create_infraction('ban', active=False) diff --git a/pydis_site/apps/api/viewsets/bot/infraction.py b/pydis_site/apps/api/viewsets/bot/infraction.py index 8a48ed1f..31e8ba40 100644 --- a/pydis_site/apps/api/viewsets/bot/infraction.py +++ b/pydis_site/apps/api/viewsets/bot/infraction.py @@ -1,5 +1,6 @@ from datetime import datetime +from django.db import IntegrityError from django.db.models import QuerySet from django.http.request import HttpRequest from django_filters.rest_framework import DjangoFilterBackend @@ -271,3 +272,20 @@ class InfractionViewSet( """ self.serializer_class = ExpandedInfractionSerializer return self.partial_update(*args, **kwargs) + + def create(self, request: HttpRequest, *args, **kwargs) -> Response: + """ + Create an infraction for a target user. + + Called by the Django Rest Framework in response to the corresponding HTTP request. + """ + try: + return super().create(request, *args, **kwargs) + except IntegrityError: + raise ValidationError( + { + 'non_field_errors': [ + 'This user already has an active infraction of this type.', + ] + } + ) |