diff options
author | 2020-10-08 00:17:47 +0200 | |
---|---|---|
committer | 2020-10-08 01:08:41 +0200 | |
commit | 484eba7715fcbcc195d66f5a60ff56c8167ecf0e (patch) | |
tree | 34348bcf83f0b750cbfeefdd5ee7440999b0c99e /pydis_site | |
parent | Reduce metricity db setup script and API response to the bare necessities. (diff) |
Broke out metricity connection into an
abstraction and added metricity endpoint unit
tests.
Diffstat (limited to 'pydis_site')
-rw-r--r-- | pydis_site/apps/api/models/bot/metricity.py | 42 | ||||
-rw-r--r-- | pydis_site/apps/api/tests/test_users.py | 58 | ||||
-rw-r--r-- | pydis_site/apps/api/viewsets/bot/user.py | 17 |
3 files changed, 109 insertions, 8 deletions
diff --git a/pydis_site/apps/api/models/bot/metricity.py b/pydis_site/apps/api/models/bot/metricity.py new file mode 100644 index 00000000..25b42fa2 --- /dev/null +++ b/pydis_site/apps/api/models/bot/metricity.py @@ -0,0 +1,42 @@ +from django.db import connections + + +class NotFound(Exception): + """Raised when an entity cannot be found.""" + + pass + + +class Metricity: + """Abstraction for a connection to the metricity database.""" + + def __init__(self): + self.cursor = connections['metricity'].cursor() + + def __enter__(self): + return self + + def __exit__(self, *_): + self.cursor.close() + + def user(self, user_id: str) -> dict: + """Query a user's data.""" + columns = ["verified_at"] + query = f"SELECT {','.join(columns)} FROM users WHERE id = '%s'" + self.cursor.execute(query, [user_id]) + values = self.cursor.fetchone() + + if not values: + raise NotFound() + + return dict(zip(columns, values)) + + def total_messages(self, user_id: str) -> int: + """Query total number of messages for a user.""" + self.cursor.execute("SELECT COUNT(*) FROM messages WHERE author_id = '%s'", [user_id]) + values = self.cursor.fetchone() + + if not values: + raise NotFound() + + return values[0] diff --git a/pydis_site/apps/api/tests/test_users.py b/pydis_site/apps/api/tests/test_users.py index a02fce8a..76a21d3a 100644 --- a/pydis_site/apps/api/tests/test_users.py +++ b/pydis_site/apps/api/tests/test_users.py @@ -1,7 +1,10 @@ +from unittest.mock import patch + from django_hosts.resolvers import reverse from .base import APISubdomainTestCase from ..models import Role, User +from ..models.bot.metricity import NotFound class UnauthedUserAPITests(APISubdomainTestCase): @@ -170,3 +173,58 @@ class UserModelTests(APISubdomainTestCase): def test_correct_username_formatting(self): """Tests the username property with both name and discriminator formatted together.""" self.assertEqual(self.user_with_roles.username, "Test User with two roles#0001") + + +class UserMetricityTests(APISubdomainTestCase): + @classmethod + def setUpTestData(cls): + User.objects.create( + id=0, + name="Test user", + discriminator=1, + in_guild=True, + ) + + def test_get_metricity_data(self): + # Given + verified_at = "foo" + total_messages = 1 + self.mock_metricity_user(verified_at, total_messages) + + # When + url = reverse('bot:user-metricity-data', args=[0], host='api') + response = self.client.get(url) + + # Then + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json(), { + "verified_at": verified_at, + "total_messages": total_messages, + }) + + def test_no_metricity_user(self): + # Given + self.mock_no_metricity_user() + + # When + url = reverse('bot:user-metricity-data', args=[0], host='api') + response = self.client.get(url) + + # Then + self.assertEqual(response.status_code, 404) + + def mock_metricity_user(self, verified_at, total_messages): + patcher = patch("pydis_site.apps.api.viewsets.bot.user.Metricity") + self.metricity = patcher.start() + self.addCleanup(patcher.stop) + self.metricity = self.metricity.return_value.__enter__.return_value + self.metricity.user.return_value = dict(verified_at=verified_at) + self.metricity.total_messages.return_value = total_messages + + def mock_no_metricity_user(self): + patcher = patch("pydis_site.apps.api.viewsets.bot.user.Metricity") + self.metricity = patcher.start() + self.addCleanup(patcher.stop) + self.metricity = self.metricity.return_value.__enter__.return_value + self.metricity.user.side_effect = NotFound() + self.metricity.total_messages.side_effect = NotFound() diff --git a/pydis_site/apps/api/viewsets/bot/user.py b/pydis_site/apps/api/viewsets/bot/user.py index 1b1af841..352d77c0 100644 --- a/pydis_site/apps/api/viewsets/bot/user.py +++ b/pydis_site/apps/api/viewsets/bot/user.py @@ -1,4 +1,3 @@ -from django.db import connections from rest_framework import status from rest_framework.decorators import action from rest_framework.request import Request @@ -6,6 +5,7 @@ from rest_framework.response import Response from rest_framework.viewsets import ModelViewSet from rest_framework_bulk import BulkCreateModelMixin +from pydis_site.apps.api.models.bot.metricity import Metricity, NotFound from pydis_site.apps.api.models.bot.user import User from pydis_site.apps.api.serializers import UserSerializer @@ -142,10 +142,11 @@ class UserViewSet(BulkCreateModelMixin, ModelViewSet): def metricity_data(self, request: Request, pk: str = None) -> Response: """Request handler for metricity_data endpoint.""" user = self.get_object() - with connections['metricity'].cursor() as cursor: - data = {} - cursor.execute("SELECT verified_at FROM users WHERE id = '%s'", [user.id]) - data["verified_at"], = cursor.fetchone() - cursor.execute("SELECT COUNT(*) FROM messages WHERE author_id = '%s'", [user.id]) - data["total_messages"], = cursor.fetchone() - return Response(data, status=status.HTTP_200_OK) + with Metricity() as metricity: + try: + data = metricity.user(user.id) + data["total_messages"] = metricity.total_messages(user.id) + return Response(data, status=status.HTTP_200_OK) + except NotFound: + return Response(dict(detail="User not found in metricity"), + status=status.HTTP_404_NOT_FOUND) |