diff options
Diffstat (limited to 'pydis_site/apps')
| -rw-r--r-- | pydis_site/apps/api/migrations/0067_add_voice_ban_infraction_type.py | 18 | ||||
| -rw-r--r-- | pydis_site/apps/api/models/bot/infraction.py | 3 | ||||
| -rw-r--r-- | pydis_site/apps/api/models/bot/metricity.py | 42 | ||||
| -rw-r--r-- | pydis_site/apps/api/serializers.py | 2 | ||||
| -rw-r--r-- | pydis_site/apps/api/tests/test_users.py | 79 | ||||
| -rw-r--r-- | pydis_site/apps/api/viewsets/bot/user.py | 38 | 
6 files changed, 180 insertions, 2 deletions
diff --git a/pydis_site/apps/api/migrations/0067_add_voice_ban_infraction_type.py b/pydis_site/apps/api/migrations/0067_add_voice_ban_infraction_type.py new file mode 100644 index 00000000..9a940ff4 --- /dev/null +++ b/pydis_site/apps/api/migrations/0067_add_voice_ban_infraction_type.py @@ -0,0 +1,18 @@ +# Generated by Django 3.0.10 on 2020-10-10 16:08 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + +    dependencies = [ +        ('api', '0066_merge_20201003_0730'), +    ] + +    operations = [ +        migrations.AlterField( +            model_name='infraction', +            name='type', +            field=models.CharField(choices=[('note', 'Note'), ('warning', 'Warning'), ('watch', 'Watch'), ('mute', 'Mute'), ('kick', 'Kick'), ('ban', 'Ban'), ('superstar', 'Superstar'), ('voice_ban', 'Voice Ban')], help_text='The type of the infraction.', max_length=9), +        ), +    ] diff --git a/pydis_site/apps/api/models/bot/infraction.py b/pydis_site/apps/api/models/bot/infraction.py index 7660cbba..60c1e8dd 100644 --- a/pydis_site/apps/api/models/bot/infraction.py +++ b/pydis_site/apps/api/models/bot/infraction.py @@ -15,7 +15,8 @@ class Infraction(ModelReprMixin, models.Model):          ("mute", "Mute"),          ("kick", "Kick"),          ("ban", "Ban"), -        ("superstar", "Superstar") +        ("superstar", "Superstar"), +        ("voice_ban", "Voice Ban"),      )      inserted_at = models.DateTimeField(          default=timezone.now, diff --git a/pydis_site/apps/api/models/bot/metricity.py b/pydis_site/apps/api/models/bot/metricity.py new file mode 100644 index 00000000..25b42fa2 --- /dev/null +++ b/pydis_site/apps/api/models/bot/metricity.py @@ -0,0 +1,42 @@ +from django.db import connections + + +class NotFound(Exception): +    """Raised when an entity cannot be found.""" + +    pass + + +class Metricity: +    """Abstraction for a connection to the metricity database.""" + +    def __init__(self): +        self.cursor = connections['metricity'].cursor() + +    def __enter__(self): +        return self + +    def __exit__(self, *_): +        self.cursor.close() + +    def user(self, user_id: str) -> dict: +        """Query a user's data.""" +        columns = ["verified_at"] +        query = f"SELECT {','.join(columns)} FROM users WHERE id = '%s'" +        self.cursor.execute(query, [user_id]) +        values = self.cursor.fetchone() + +        if not values: +            raise NotFound() + +        return dict(zip(columns, values)) + +    def total_messages(self, user_id: str) -> int: +        """Query total number of messages for a user.""" +        self.cursor.execute("SELECT COUNT(*) FROM messages WHERE author_id = '%s'", [user_id]) +        values = self.cursor.fetchone() + +        if not values: +            raise NotFound() + +        return values[0] diff --git a/pydis_site/apps/api/serializers.py b/pydis_site/apps/api/serializers.py index 25c5c82e..10eb3839 100644 --- a/pydis_site/apps/api/serializers.py +++ b/pydis_site/apps/api/serializers.py @@ -167,7 +167,7 @@ class InfractionSerializer(ModelSerializer):              raise ValidationError({'expires_at': [f'{infr_type} infractions cannot expire.']})          hidden = attrs.get('hidden') -        if hidden and infr_type in ('superstar', 'warning'): +        if hidden and infr_type in ('superstar', 'warning', 'voice_ban'):              raise ValidationError({'hidden': [f'{infr_type} infractions cannot be hidden.']})          if not hidden and infr_type in ('note', ): diff --git a/pydis_site/apps/api/tests/test_users.py b/pydis_site/apps/api/tests/test_users.py index 825e4edb..72ffcb3c 100644 --- a/pydis_site/apps/api/tests/test_users.py +++ b/pydis_site/apps/api/tests/test_users.py @@ -1,7 +1,11 @@ +from unittest.mock import patch + +from django.core.exceptions import ObjectDoesNotExist  from django_hosts.resolvers import reverse  from .base import APISubdomainTestCase  from ..models import Role, User +from ..models.bot.metricity import NotFound  class UnauthedUserAPITests(APISubdomainTestCase): @@ -389,3 +393,78 @@ class UserPaginatorTests(APISubdomainTestCase):          url = reverse("bot:user-list", host="api")          response = self.client.get(url, {"page": 2}).json()          self.assertEqual(1, response["previous_page_no"]) + + +class UserMetricityTests(APISubdomainTestCase): +    @classmethod +    def setUpTestData(cls): +        User.objects.create( +            id=0, +            name="Test user", +            discriminator=1, +            in_guild=True, +        ) + +    def test_get_metricity_data(self): +        # Given +        verified_at = "foo" +        total_messages = 1 +        self.mock_metricity_user(verified_at, total_messages) + +        # When +        url = reverse('bot:user-metricity-data', args=[0], host='api') +        response = self.client.get(url) + +        # Then +        self.assertEqual(response.status_code, 200) +        self.assertEqual(response.json(), { +            "verified_at": verified_at, +            "total_messages": total_messages, +            "voice_banned": False, +        }) + +    def test_no_metricity_user(self): +        # Given +        self.mock_no_metricity_user() + +        # When +        url = reverse('bot:user-metricity-data', args=[0], host='api') +        response = self.client.get(url) + +        # Then +        self.assertEqual(response.status_code, 404) + +    def test_metricity_voice_banned(self): +        cases = [ +            {'exception': None, 'voice_banned': True}, +            {'exception': ObjectDoesNotExist, 'voice_banned': False}, +        ] + +        self.mock_metricity_user("foo", 1) + +        for case in cases: +            with self.subTest(exception=case['exception'], voice_banned=case['voice_banned']): +                with patch("pydis_site.apps.api.viewsets.bot.user.Infraction.objects.get") as p: +                    p.side_effect = case['exception'] + +                    url = reverse('bot:user-metricity-data', args=[0], host='api') +                    response = self.client.get(url) + +                    self.assertEqual(response.status_code, 200) +                    self.assertEqual(response.json()["voice_banned"], case["voice_banned"]) + +    def mock_metricity_user(self, verified_at, total_messages): +        patcher = patch("pydis_site.apps.api.viewsets.bot.user.Metricity") +        self.metricity = patcher.start() +        self.addCleanup(patcher.stop) +        self.metricity = self.metricity.return_value.__enter__.return_value +        self.metricity.user.return_value = dict(verified_at=verified_at) +        self.metricity.total_messages.return_value = total_messages + +    def mock_no_metricity_user(self): +        patcher = patch("pydis_site.apps.api.viewsets.bot.user.Metricity") +        self.metricity = patcher.start() +        self.addCleanup(patcher.stop) +        self.metricity = self.metricity.return_value.__enter__.return_value +        self.metricity.user.side_effect = NotFound() +        self.metricity.total_messages.side_effect = NotFound() diff --git a/pydis_site/apps/api/viewsets/bot/user.py b/pydis_site/apps/api/viewsets/bot/user.py index 3e4b627e..5205dc97 100644 --- a/pydis_site/apps/api/viewsets/bot/user.py +++ b/pydis_site/apps/api/viewsets/bot/user.py @@ -1,6 +1,7 @@  import typing  from collections import OrderedDict +from django.core.exceptions import ObjectDoesNotExist  from rest_framework import status  from rest_framework.decorators import action  from rest_framework.pagination import PageNumberPagination @@ -9,6 +10,8 @@ from rest_framework.response import Response  from rest_framework.serializers import ModelSerializer  from rest_framework.viewsets import ModelViewSet +from pydis_site.apps.api.models.bot.infraction import Infraction +from pydis_site.apps.api.models.bot.metricity import Metricity, NotFound  from pydis_site.apps.api.models.bot.user import User  from pydis_site.apps.api.serializers import UserSerializer @@ -101,6 +104,19 @@ class UserViewSet(ModelViewSet):      - 200: returned on success      - 404: if a user with the given `snowflake` could not be found +    ### GET /bot/users/<snowflake:int>/metricity_data +    Gets metricity data for a single user by ID. + +    #### Response format +    >>> { +    ...    "verified_at": "2020-10-06T21:54:23.540766", +    ...    "total_messages": 2 +    ...} + +    #### Status codes +    - 200: returned on success +    - 404: if a user with the given `snowflake` could not be found +      ### POST /bot/users      Adds a single or multiple new users.      The roles attached to the user(s) must be roles known by the site. @@ -221,3 +237,25 @@ class UserViewSet(ModelViewSet):          serializer.save()          return Response(serializer.data, status=status.HTTP_200_OK) + +    @action(detail=True) +    def metricity_data(self, request: Request, pk: str = None) -> Response: +        """Request handler for metricity_data endpoint.""" +        user = self.get_object() + +        try: +            Infraction.objects.get(user__id=user.id, active=True, type="voice_ban") +        except ObjectDoesNotExist: +            voice_banned = False +        else: +            voice_banned = True + +        with Metricity() as metricity: +            try: +                data = metricity.user(user.id) +                data["total_messages"] = metricity.total_messages(user.id) +                data["voice_banned"] = voice_banned +                return Response(data, status=status.HTTP_200_OK) +            except NotFound: +                return Response(dict(detail="User not found in metricity"), +                                status=status.HTTP_404_NOT_FOUND)  |