aboutsummaryrefslogtreecommitdiffstats
path: root/pydis_site/apps
diff options
context:
space:
mode:
authorGravatar Lucas Lindström <[email protected]>2020-10-08 00:17:47 +0200
committerGravatar Lucas Lindström <[email protected]>2020-10-08 01:08:41 +0200
commit484eba7715fcbcc195d66f5a60ff56c8167ecf0e (patch)
tree34348bcf83f0b750cbfeefdd5ee7440999b0c99e /pydis_site/apps
parentReduce metricity db setup script and API response to the bare necessities. (diff)
Broke out metricity connection into an
abstraction and added metricity endpoint unit tests.
Diffstat (limited to 'pydis_site/apps')
-rw-r--r--pydis_site/apps/api/models/bot/metricity.py42
-rw-r--r--pydis_site/apps/api/tests/test_users.py58
-rw-r--r--pydis_site/apps/api/viewsets/bot/user.py17
3 files changed, 109 insertions, 8 deletions
diff --git a/pydis_site/apps/api/models/bot/metricity.py b/pydis_site/apps/api/models/bot/metricity.py
new file mode 100644
index 00000000..25b42fa2
--- /dev/null
+++ b/pydis_site/apps/api/models/bot/metricity.py
@@ -0,0 +1,42 @@
+from django.db import connections
+
+
+class NotFound(Exception):
+ """Raised when an entity cannot be found."""
+
+ pass
+
+
+class Metricity:
+ """Abstraction for a connection to the metricity database."""
+
+ def __init__(self):
+ self.cursor = connections['metricity'].cursor()
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, *_):
+ self.cursor.close()
+
+ def user(self, user_id: str) -> dict:
+ """Query a user's data."""
+ columns = ["verified_at"]
+ query = f"SELECT {','.join(columns)} FROM users WHERE id = '%s'"
+ self.cursor.execute(query, [user_id])
+ values = self.cursor.fetchone()
+
+ if not values:
+ raise NotFound()
+
+ return dict(zip(columns, values))
+
+ def total_messages(self, user_id: str) -> int:
+ """Query total number of messages for a user."""
+ self.cursor.execute("SELECT COUNT(*) FROM messages WHERE author_id = '%s'", [user_id])
+ values = self.cursor.fetchone()
+
+ if not values:
+ raise NotFound()
+
+ return values[0]
diff --git a/pydis_site/apps/api/tests/test_users.py b/pydis_site/apps/api/tests/test_users.py
index a02fce8a..76a21d3a 100644
--- a/pydis_site/apps/api/tests/test_users.py
+++ b/pydis_site/apps/api/tests/test_users.py
@@ -1,7 +1,10 @@
+from unittest.mock import patch
+
from django_hosts.resolvers import reverse
from .base import APISubdomainTestCase
from ..models import Role, User
+from ..models.bot.metricity import NotFound
class UnauthedUserAPITests(APISubdomainTestCase):
@@ -170,3 +173,58 @@ class UserModelTests(APISubdomainTestCase):
def test_correct_username_formatting(self):
"""Tests the username property with both name and discriminator formatted together."""
self.assertEqual(self.user_with_roles.username, "Test User with two roles#0001")
+
+
+class UserMetricityTests(APISubdomainTestCase):
+ @classmethod
+ def setUpTestData(cls):
+ User.objects.create(
+ id=0,
+ name="Test user",
+ discriminator=1,
+ in_guild=True,
+ )
+
+ def test_get_metricity_data(self):
+ # Given
+ verified_at = "foo"
+ total_messages = 1
+ self.mock_metricity_user(verified_at, total_messages)
+
+ # When
+ url = reverse('bot:user-metricity-data', args=[0], host='api')
+ response = self.client.get(url)
+
+ # Then
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.json(), {
+ "verified_at": verified_at,
+ "total_messages": total_messages,
+ })
+
+ def test_no_metricity_user(self):
+ # Given
+ self.mock_no_metricity_user()
+
+ # When
+ url = reverse('bot:user-metricity-data', args=[0], host='api')
+ response = self.client.get(url)
+
+ # Then
+ self.assertEqual(response.status_code, 404)
+
+ def mock_metricity_user(self, verified_at, total_messages):
+ patcher = patch("pydis_site.apps.api.viewsets.bot.user.Metricity")
+ self.metricity = patcher.start()
+ self.addCleanup(patcher.stop)
+ self.metricity = self.metricity.return_value.__enter__.return_value
+ self.metricity.user.return_value = dict(verified_at=verified_at)
+ self.metricity.total_messages.return_value = total_messages
+
+ def mock_no_metricity_user(self):
+ patcher = patch("pydis_site.apps.api.viewsets.bot.user.Metricity")
+ self.metricity = patcher.start()
+ self.addCleanup(patcher.stop)
+ self.metricity = self.metricity.return_value.__enter__.return_value
+ self.metricity.user.side_effect = NotFound()
+ self.metricity.total_messages.side_effect = NotFound()
diff --git a/pydis_site/apps/api/viewsets/bot/user.py b/pydis_site/apps/api/viewsets/bot/user.py
index 1b1af841..352d77c0 100644
--- a/pydis_site/apps/api/viewsets/bot/user.py
+++ b/pydis_site/apps/api/viewsets/bot/user.py
@@ -1,4 +1,3 @@
-from django.db import connections
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.request import Request
@@ -6,6 +5,7 @@ from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
from rest_framework_bulk import BulkCreateModelMixin
+from pydis_site.apps.api.models.bot.metricity import Metricity, NotFound
from pydis_site.apps.api.models.bot.user import User
from pydis_site.apps.api.serializers import UserSerializer
@@ -142,10 +142,11 @@ class UserViewSet(BulkCreateModelMixin, ModelViewSet):
def metricity_data(self, request: Request, pk: str = None) -> Response:
"""Request handler for metricity_data endpoint."""
user = self.get_object()
- with connections['metricity'].cursor() as cursor:
- data = {}
- cursor.execute("SELECT verified_at FROM users WHERE id = '%s'", [user.id])
- data["verified_at"], = cursor.fetchone()
- cursor.execute("SELECT COUNT(*) FROM messages WHERE author_id = '%s'", [user.id])
- data["total_messages"], = cursor.fetchone()
- return Response(data, status=status.HTTP_200_OK)
+ with Metricity() as metricity:
+ try:
+ data = metricity.user(user.id)
+ data["total_messages"] = metricity.total_messages(user.id)
+ return Response(data, status=status.HTTP_200_OK)
+ except NotFound:
+ return Response(dict(detail="User not found in metricity"),
+ status=status.HTTP_404_NOT_FOUND)