aboutsummaryrefslogtreecommitdiffstats
path: root/pydis_site/apps/api
diff options
context:
space:
mode:
Diffstat (limited to 'pydis_site/apps/api')
-rw-r--r--pydis_site/apps/api/__init__.py1
-rw-r--r--pydis_site/apps/api/github_utils.py207
-rw-r--r--pydis_site/apps/api/migrations/0013_specialsnake_image.py3
-rw-r--r--pydis_site/apps/api/migrations/0019_deletedmessage.py2
-rw-r--r--pydis_site/apps/api/migrations/0051_allow_blank_message_embeds.py3
-rw-r--r--pydis_site/apps/api/migrations/0077_use_generic_jsonfield.py3
-rw-r--r--pydis_site/apps/api/migrations/0082_otn_allow_big_solidus.py19
-rw-r--r--pydis_site/apps/api/migrations/0083_remove_embed_validation.py19
-rw-r--r--pydis_site/apps/api/migrations/0084_infraction_last_applied.py26
-rw-r--r--pydis_site/apps/api/models/bot/infraction.py6
-rw-r--r--pydis_site/apps/api/models/bot/message.py16
-rw-r--r--pydis_site/apps/api/models/bot/metricity.py28
-rw-r--r--pydis_site/apps/api/models/bot/off_topic_channel_name.py2
-rw-r--r--pydis_site/apps/api/models/utils.py172
-rw-r--r--pydis_site/apps/api/pagination.py5
-rw-r--r--pydis_site/apps/api/serializers.py1
-rw-r--r--pydis_site/apps/api/tests/migrations/__init__.py1
-rw-r--r--pydis_site/apps/api/tests/migrations/base.py102
-rw-r--r--pydis_site/apps/api/tests/migrations/test_active_infraction_migration.py496
-rw-r--r--pydis_site/apps/api/tests/migrations/test_base.py135
-rw-r--r--pydis_site/apps/api/tests/test_filterlists.py4
-rw-r--r--pydis_site/apps/api/tests/test_github_utils.py286
-rw-r--r--pydis_site/apps/api/tests/test_infractions.py15
-rw-r--r--pydis_site/apps/api/tests/test_models.py12
-rw-r--r--pydis_site/apps/api/tests/test_users.py84
-rw-r--r--pydis_site/apps/api/tests/test_validators.py229
-rw-r--r--pydis_site/apps/api/urls.py9
-rw-r--r--pydis_site/apps/api/views.py96
-rw-r--r--pydis_site/apps/api/viewsets/bot/aoc_completionist_block.py2
-rw-r--r--pydis_site/apps/api/viewsets/bot/aoc_link.py2
-rw-r--r--pydis_site/apps/api/viewsets/bot/infraction.py19
-rw-r--r--pydis_site/apps/api/viewsets/bot/nomination.py2
-rw-r--r--pydis_site/apps/api/viewsets/bot/reminder.py2
-rw-r--r--pydis_site/apps/api/viewsets/bot/user.py59
34 files changed, 860 insertions, 1208 deletions
diff --git a/pydis_site/apps/api/__init__.py b/pydis_site/apps/api/__init__.py
index afa5b4d5..e69de29b 100644
--- a/pydis_site/apps/api/__init__.py
+++ b/pydis_site/apps/api/__init__.py
@@ -1 +0,0 @@
-default_app_config = 'pydis_site.apps.api.apps.ApiConfig'
diff --git a/pydis_site/apps/api/github_utils.py b/pydis_site/apps/api/github_utils.py
new file mode 100644
index 00000000..44c571c3
--- /dev/null
+++ b/pydis_site/apps/api/github_utils.py
@@ -0,0 +1,207 @@
+"""Utilities for working with the GitHub API."""
+import dataclasses
+import datetime
+import math
+import typing
+
+import httpx
+import jwt
+
+from pydis_site import settings
+
+MAX_RUN_TIME = datetime.timedelta(minutes=10)
+"""The maximum time allowed before an action is declared timed out."""
+
+
+class ArtifactProcessingError(Exception):
+ """Base exception for other errors related to processing a GitHub artifact."""
+
+ status: int
+
+
+class UnauthorizedError(ArtifactProcessingError):
+ """The application does not have permission to access the requested repo."""
+
+ status = 401
+
+
+class NotFoundError(ArtifactProcessingError):
+ """The requested resource could not be found."""
+
+ status = 404
+
+
+class ActionFailedError(ArtifactProcessingError):
+ """The requested workflow did not conclude successfully."""
+
+ status = 400
+
+
+class RunTimeoutError(ArtifactProcessingError):
+ """The requested workflow run was not ready in time."""
+
+ status = 408
+
+
+class RunPendingError(ArtifactProcessingError):
+ """The requested workflow run is still pending, try again later."""
+
+ status = 202
+
+
[email protected](frozen=True)
+class WorkflowRun:
+ """
+ A workflow run from the GitHub API.
+
+ https://docs.github.com/en/rest/actions/workflow-runs#get-a-workflow-run
+ """
+
+ name: str
+ head_sha: str
+ created_at: str
+ status: str
+ conclusion: str
+ artifacts_url: str
+
+ @classmethod
+ def from_raw(cls, data: dict[str, typing.Any]):
+ """Create an instance using the raw data from the API, discarding unused fields."""
+ return cls(**{
+ key.name: data[key.name] for key in dataclasses.fields(cls)
+ })
+
+
+def generate_token() -> str:
+ """
+ Generate a JWT token to access the GitHub API.
+
+ The token is valid for roughly 10 minutes after generation, before the API starts
+ returning 401s.
+
+ Refer to:
+ https://docs.github.com/en/developers/apps/building-github-apps/authenticating-with-github-apps#authenticating-as-a-github-app
+ """
+ now = datetime.datetime.now()
+ return jwt.encode(
+ {
+ "iat": math.floor((now - datetime.timedelta(seconds=60)).timestamp()), # Issued at
+ "exp": math.floor((now + datetime.timedelta(minutes=9)).timestamp()), # Expires at
+ "iss": settings.GITHUB_APP_ID,
+ },
+ settings.GITHUB_APP_KEY,
+ algorithm="RS256"
+ )
+
+
+def authorize(owner: str, repo: str) -> httpx.Client:
+ """
+ Get an access token for the requested repository.
+
+ The process is roughly:
+ - GET app/installations to get a list of all app installations
+ - POST <app_access_token> to get a token to access the given app
+ - GET installation/repositories and check if the requested one is part of those
+ """
+ client = httpx.Client(
+ base_url=settings.GITHUB_API,
+ headers={"Authorization": f"bearer {generate_token()}"},
+ timeout=10,
+ )
+
+ try:
+ # Get a list of app installations we have access to
+ apps = client.get("app/installations")
+ apps.raise_for_status()
+
+ for app in apps.json():
+ # Look for an installation with the right owner
+ if app["account"]["login"] != owner:
+ continue
+
+ # Get the repositories of the specified owner
+ app_token = client.post(app["access_tokens_url"])
+ app_token.raise_for_status()
+ client.headers["Authorization"] = f"bearer {app_token.json()['token']}"
+
+ repos = client.get("installation/repositories")
+ repos.raise_for_status()
+
+ # Search for the request repository
+ for accessible_repo in repos.json()["repositories"]:
+ if accessible_repo["name"] == repo:
+ # We've found the correct repository, and it's accessible with the current auth
+ return client
+
+ raise NotFoundError(
+ "Could not find the requested repository. Make sure the application can access it."
+ )
+
+ except BaseException as e:
+ # Close the client if we encountered an unexpected exception
+ client.close()
+ raise e
+
+
+def check_run_status(run: WorkflowRun) -> str:
+ """Check if the provided run has been completed, otherwise raise an exception."""
+ created_at = datetime.datetime.strptime(run.created_at, settings.GITHUB_TIMESTAMP_FORMAT)
+ run_time = datetime.datetime.utcnow() - created_at
+
+ if run.status != "completed":
+ if run_time <= MAX_RUN_TIME:
+ raise RunPendingError(
+ f"The requested run is still pending. It was created "
+ f"{run_time.seconds // 60}:{run_time.seconds % 60 :>02} minutes ago."
+ )
+ else:
+ raise RunTimeoutError("The requested workflow was not ready in time.")
+
+ if run.conclusion != "success":
+ # The action failed, or did not run
+ raise ActionFailedError(f"The requested workflow ended with: {run.conclusion}")
+
+ # The requested action is ready
+ return run.artifacts_url
+
+
+def get_artifact(owner: str, repo: str, sha: str, action_name: str, artifact_name: str) -> str:
+ """Get a download URL for a build artifact."""
+ client = authorize(owner, repo)
+
+ try:
+ # Get the workflow runs for this repository
+ runs = client.get(f"/repos/{owner}/{repo}/actions/runs", params={"per_page": 100})
+ runs.raise_for_status()
+ runs = runs.json()
+
+ # Filter the runs for the one associated with the given SHA
+ for run in runs["workflow_runs"]:
+ run = WorkflowRun.from_raw(run)
+ if run.name == action_name and sha == run.head_sha:
+ break
+ else:
+ raise NotFoundError(
+ "Could not find a run matching the provided settings in the previous hundred runs."
+ )
+
+ # Check the workflow status
+ url = check_run_status(run)
+
+ # Filter the artifacts, and return the download URL
+ artifacts = client.get(url)
+ artifacts.raise_for_status()
+
+ for artifact in artifacts.json()["artifacts"]:
+ if artifact["name"] == artifact_name:
+ data = client.get(artifact["archive_download_url"])
+ if data.status_code == 302:
+ return str(data.next_request.url)
+
+ # The following line is left untested since it should in theory be impossible
+ data.raise_for_status() # pragma: no cover
+
+ raise NotFoundError("Could not find an artifact matching the provided name.")
+
+ finally:
+ client.close()
diff --git a/pydis_site/apps/api/migrations/0013_specialsnake_image.py b/pydis_site/apps/api/migrations/0013_specialsnake_image.py
index a0d0d318..8ba3432f 100644
--- a/pydis_site/apps/api/migrations/0013_specialsnake_image.py
+++ b/pydis_site/apps/api/migrations/0013_specialsnake_image.py
@@ -2,7 +2,6 @@
import datetime
from django.db import migrations, models
-from django.utils.timezone import utc
class Migration(migrations.Migration):
@@ -15,7 +14,7 @@ class Migration(migrations.Migration):
migrations.AddField(
model_name='specialsnake',
name='image',
- field=models.URLField(default=datetime.datetime(2018, 10, 23, 11, 51, 23, 703868, tzinfo=utc)),
+ field=models.URLField(default=datetime.datetime(2018, 10, 23, 11, 51, 23, 703868, tzinfo=datetime.timezone.utc)),
preserve_default=False,
),
]
diff --git a/pydis_site/apps/api/migrations/0019_deletedmessage.py b/pydis_site/apps/api/migrations/0019_deletedmessage.py
index 6b848d64..25d04434 100644
--- a/pydis_site/apps/api/migrations/0019_deletedmessage.py
+++ b/pydis_site/apps/api/migrations/0019_deletedmessage.py
@@ -18,7 +18,7 @@ class Migration(migrations.Migration):
('id', models.BigIntegerField(help_text='The message ID as taken from Discord.', primary_key=True, serialize=False, validators=[django.core.validators.MinValueValidator(limit_value=0, message='Message IDs cannot be negative.')])),
('channel_id', models.BigIntegerField(help_text='The channel ID that this message was sent in, taken from Discord.', validators=[django.core.validators.MinValueValidator(limit_value=0, message='Channel IDs cannot be negative.')])),
('content', models.CharField(help_text='The content of this message, taken from Discord.', max_length=2000)),
- ('embeds', django.contrib.postgres.fields.ArrayField(base_field=django.contrib.postgres.fields.jsonb.JSONField(validators=[pydis_site.apps.api.models.utils.validate_embed]), help_text='Embeds attached to this message.', size=None)),
+ ('embeds', django.contrib.postgres.fields.ArrayField(base_field=django.contrib.postgres.fields.jsonb.JSONField(validators=[]), help_text='Embeds attached to this message.', size=None)),
('author', models.ForeignKey(help_text='The author of this message.', on_delete=django.db.models.deletion.CASCADE, to='api.User')),
('deletion_context', models.ForeignKey(help_text='The deletion context this message is part of.', on_delete=django.db.models.deletion.CASCADE, to='api.MessageDeletionContext')),
],
diff --git a/pydis_site/apps/api/migrations/0051_allow_blank_message_embeds.py b/pydis_site/apps/api/migrations/0051_allow_blank_message_embeds.py
index 124c6a57..622f21d1 100644
--- a/pydis_site/apps/api/migrations/0051_allow_blank_message_embeds.py
+++ b/pydis_site/apps/api/migrations/0051_allow_blank_message_embeds.py
@@ -3,7 +3,6 @@
import django.contrib.postgres.fields
import django.contrib.postgres.fields.jsonb
from django.db import migrations
-import pydis_site.apps.api.models.utils
class Migration(migrations.Migration):
@@ -16,6 +15,6 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name='deletedmessage',
name='embeds',
- field=django.contrib.postgres.fields.ArrayField(base_field=django.contrib.postgres.fields.jsonb.JSONField(validators=[pydis_site.apps.api.models.utils.validate_embed]), blank=True, help_text='Embeds attached to this message.', size=None),
+ field=django.contrib.postgres.fields.ArrayField(base_field=django.contrib.postgres.fields.jsonb.JSONField(validators=[]), blank=True, help_text='Embeds attached to this message.', size=None),
),
]
diff --git a/pydis_site/apps/api/migrations/0077_use_generic_jsonfield.py b/pydis_site/apps/api/migrations/0077_use_generic_jsonfield.py
index 9e8f2fb9..95ef5850 100644
--- a/pydis_site/apps/api/migrations/0077_use_generic_jsonfield.py
+++ b/pydis_site/apps/api/migrations/0077_use_generic_jsonfield.py
@@ -2,7 +2,6 @@
import django.contrib.postgres.fields
from django.db import migrations, models
-import pydis_site.apps.api.models.utils
class Migration(migrations.Migration):
@@ -20,6 +19,6 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name='deletedmessage',
name='embeds',
- field=django.contrib.postgres.fields.ArrayField(base_field=models.JSONField(validators=[pydis_site.apps.api.models.utils.validate_embed]), blank=True, help_text='Embeds attached to this message.', size=None),
+ field=django.contrib.postgres.fields.ArrayField(base_field=models.JSONField(validators=[]), blank=True, help_text='Embeds attached to this message.', size=None),
),
]
diff --git a/pydis_site/apps/api/migrations/0082_otn_allow_big_solidus.py b/pydis_site/apps/api/migrations/0082_otn_allow_big_solidus.py
new file mode 100644
index 00000000..abbb98ec
--- /dev/null
+++ b/pydis_site/apps/api/migrations/0082_otn_allow_big_solidus.py
@@ -0,0 +1,19 @@
+# Generated by Django 3.1.14 on 2022-04-21 23:29
+
+import django.core.validators
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('api', '0081_bumpedthread'),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name='offtopicchannelname',
+ name='name',
+ field=models.CharField(help_text='The actual channel name that will be used on our Discord server.', max_length=96, primary_key=True, serialize=False, validators=[django.core.validators.RegexValidator(regex="^[a-z0-9\\U0001d5a0-\\U0001d5b9-ǃ?’'<>⧹⧸]+$")]),
+ ),
+ ]
diff --git a/pydis_site/apps/api/migrations/0083_remove_embed_validation.py b/pydis_site/apps/api/migrations/0083_remove_embed_validation.py
new file mode 100644
index 00000000..e835bb66
--- /dev/null
+++ b/pydis_site/apps/api/migrations/0083_remove_embed_validation.py
@@ -0,0 +1,19 @@
+# Generated by Django 3.1.14 on 2022-06-30 09:41
+
+import django.contrib.postgres.fields
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('api', '0082_otn_allow_big_solidus'),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name='deletedmessage',
+ name='embeds',
+ field=django.contrib.postgres.fields.ArrayField(base_field=models.JSONField(), blank=True, help_text='Embeds attached to this message.', size=None),
+ ),
+ ]
diff --git a/pydis_site/apps/api/migrations/0084_infraction_last_applied.py b/pydis_site/apps/api/migrations/0084_infraction_last_applied.py
new file mode 100644
index 00000000..7704ddb8
--- /dev/null
+++ b/pydis_site/apps/api/migrations/0084_infraction_last_applied.py
@@ -0,0 +1,26 @@
+# Generated by Django 4.0.6 on 2022-07-27 20:32
+
+import django.utils.timezone
+from django.db import migrations, models
+from django.apps.registry import Apps
+
+
+def set_last_applied_to_inserted_at(apps: Apps, schema_editor):
+ Infractions = apps.get_model("api", "infraction")
+ Infractions.objects.all().update(last_applied=models.F("inserted_at"))
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('api', '0083_remove_embed_validation'),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name='infraction',
+ name='last_applied',
+ field=models.DateTimeField(default=django.utils.timezone.now, help_text='The date and time of when this infraction was last applied.'),
+ ),
+ migrations.RunPython(set_last_applied_to_inserted_at)
+ ]
diff --git a/pydis_site/apps/api/models/bot/infraction.py b/pydis_site/apps/api/models/bot/infraction.py
index c9303024..218ee5ec 100644
--- a/pydis_site/apps/api/models/bot/infraction.py
+++ b/pydis_site/apps/api/models/bot/infraction.py
@@ -23,6 +23,12 @@ class Infraction(ModelReprMixin, models.Model):
default=timezone.now,
help_text="The date and time of the creation of this infraction."
)
+ last_applied = models.DateTimeField(
+ # This default is for backwards compatibility with bot versions
+ # that don't explicitly give a value.
+ default=timezone.now,
+ help_text="The date and time of when this infraction was last applied."
+ )
expires_at = models.DateTimeField(
null=True,
help_text=(
diff --git a/pydis_site/apps/api/models/bot/message.py b/pydis_site/apps/api/models/bot/message.py
index bab3368d..89ae27e4 100644
--- a/pydis_site/apps/api/models/bot/message.py
+++ b/pydis_site/apps/api/models/bot/message.py
@@ -1,13 +1,11 @@
-from datetime import datetime
+import datetime
from django.contrib.postgres import fields as pgfields
from django.core.validators import MinValueValidator
from django.db import models
-from django.utils import timezone
from pydis_site.apps.api.models.bot.user import User
from pydis_site.apps.api.models.mixins import ModelReprMixin
-from pydis_site.apps.api.models.utils import validate_embed
class Message(ModelReprMixin, models.Model):
@@ -48,9 +46,7 @@ class Message(ModelReprMixin, models.Model):
blank=True
)
embeds = pgfields.ArrayField(
- models.JSONField(
- validators=(validate_embed,)
- ),
+ models.JSONField(),
blank=True,
help_text="Embeds attached to this message."
)
@@ -63,11 +59,11 @@ class Message(ModelReprMixin, models.Model):
)
@property
- def timestamp(self) -> datetime:
+ def timestamp(self) -> datetime.datetime:
"""Attribute that represents the message timestamp as derived from the snowflake id."""
- tz_naive_datetime = datetime.utcfromtimestamp(((self.id >> 22) + 1420070400000) / 1000)
- tz_aware_datetime = timezone.make_aware(tz_naive_datetime, timezone=timezone.utc)
- return tz_aware_datetime
+ return datetime.datetime.utcfromtimestamp(
+ ((self.id >> 22) + 1420070400000) / 1000
+ ).replace(tzinfo=datetime.timezone.utc)
class Meta:
"""Metadata provided for Django's ORM."""
diff --git a/pydis_site/apps/api/models/bot/metricity.py b/pydis_site/apps/api/models/bot/metricity.py
index abd25ef0..f53dd33c 100644
--- a/pydis_site/apps/api/models/bot/metricity.py
+++ b/pydis_site/apps/api/models/bot/metricity.py
@@ -130,3 +130,31 @@ class Metricity:
raise NotFoundError()
return values
+
+ def total_messages_in_past_n_days(
+ self,
+ user_ids: list[str],
+ days: int
+ ) -> list[tuple[str, int]]:
+ """
+ Query activity by a list of users in the past `days` days.
+
+ Returns a list of (user_id, message_count) tuples.
+ """
+ self.cursor.execute(
+ """
+ SELECT
+ author_id, COUNT(*)
+ FROM messages
+ WHERE
+ author_id IN %s
+ AND NOT is_deleted
+ AND channel_id NOT IN %s
+ AND created_at > now() - interval '%s days'
+ GROUP BY author_id
+ """,
+ [tuple(user_ids), EXCLUDE_CHANNELS, days]
+ )
+ values = self.cursor.fetchall()
+
+ return values
diff --git a/pydis_site/apps/api/models/bot/off_topic_channel_name.py b/pydis_site/apps/api/models/bot/off_topic_channel_name.py
index e9fec114..b380efad 100644
--- a/pydis_site/apps/api/models/bot/off_topic_channel_name.py
+++ b/pydis_site/apps/api/models/bot/off_topic_channel_name.py
@@ -11,7 +11,7 @@ class OffTopicChannelName(ModelReprMixin, models.Model):
primary_key=True,
max_length=96,
validators=(
- RegexValidator(regex=r"^[a-z0-9\U0001d5a0-\U0001d5b9-ǃ?’'<>]+$"),
+ RegexValidator(regex=r"^[a-z0-9\U0001d5a0-\U0001d5b9-ǃ?’'<>⧹⧸]+$"),
),
help_text="The actual channel name that will be used on our Discord server."
)
diff --git a/pydis_site/apps/api/models/utils.py b/pydis_site/apps/api/models/utils.py
deleted file mode 100644
index 859394d2..00000000
--- a/pydis_site/apps/api/models/utils.py
+++ /dev/null
@@ -1,172 +0,0 @@
-from collections.abc import Mapping
-from typing import Any, Dict
-
-from django.core.exceptions import ValidationError
-from django.core.validators import MaxLengthValidator, MinLengthValidator
-
-
-def is_bool_validator(value: Any) -> None:
- """Validates if a given value is of type bool."""
- if not isinstance(value, bool):
- raise ValidationError(f"This field must be of type bool, not {type(value)}.")
-
-
-def validate_embed_fields(fields: dict) -> None:
- """Raises a ValidationError if any of the given embed fields is invalid."""
- field_validators = {
- 'name': (MaxLengthValidator(limit_value=256),),
- 'value': (MaxLengthValidator(limit_value=1024),),
- 'inline': (is_bool_validator,),
- }
-
- required_fields = ('name', 'value')
-
- for field in fields:
- if not isinstance(field, Mapping):
- raise ValidationError("Embed fields must be a mapping.")
-
- if not all(required_field in field for required_field in required_fields):
- raise ValidationError(
- f"Embed fields must contain the following fields: {', '.join(required_fields)}."
- )
-
- for field_name, value in field.items():
- if field_name not in field_validators:
- raise ValidationError(f"Unknown embed field field: {field_name!r}.")
-
- for validator in field_validators[field_name]:
- validator(value)
-
-
-def validate_embed_footer(footer: Dict[str, str]) -> None:
- """Raises a ValidationError if the given footer is invalid."""
- field_validators = {
- 'text': (
- MinLengthValidator(
- limit_value=1,
- message="Footer text must not be empty."
- ),
- MaxLengthValidator(limit_value=2048)
- ),
- 'icon_url': (),
- 'proxy_icon_url': ()
- }
-
- if not isinstance(footer, Mapping):
- raise ValidationError("Embed footer must be a mapping.")
-
- for field_name, value in footer.items():
- if field_name not in field_validators:
- raise ValidationError(f"Unknown embed footer field: {field_name!r}.")
-
- for validator in field_validators[field_name]:
- validator(value)
-
-
-def validate_embed_author(author: Any) -> None:
- """Raises a ValidationError if the given author is invalid."""
- field_validators = {
- 'name': (
- MinLengthValidator(
- limit_value=1,
- message="Embed author name must not be empty."
- ),
- MaxLengthValidator(limit_value=256)
- ),
- 'url': (),
- 'icon_url': (),
- 'proxy_icon_url': ()
- }
-
- if not isinstance(author, Mapping):
- raise ValidationError("Embed author must be a mapping.")
-
- for field_name, value in author.items():
- if field_name not in field_validators:
- raise ValidationError(f"Unknown embed author field: {field_name!r}.")
-
- for validator in field_validators[field_name]:
- validator(value)
-
-
-def validate_embed(embed: Any) -> None:
- """
- Validate a JSON document containing an embed as possible to send on Discord.
-
- This attempts to rebuild the validation used by Discord
- as well as possible by checking for various embed limits so we can
- ensure that any embed we store here will also be accepted as a
- valid embed by the Discord API.
-
- Using this directly is possible, although not intended - you usually
- stick this onto the `validators` keyword argument of model fields.
-
- Example:
-
- >>> from django.db import models
- >>> from pydis_site.apps.api.models.utils import validate_embed
- >>> class MyMessage(models.Model):
- ... embed = models.JSONField(
- ... validators=(
- ... validate_embed,
- ... )
- ... )
- ... # ...
- ...
-
- Args:
- embed (Any):
- A dictionary describing the contents of this embed.
- See the official documentation for a full reference
- of accepted keys by this dictionary:
- https://discordapp.com/developers/docs/resources/channel#embed-object
-
- Raises:
- ValidationError:
- In case the given embed is deemed invalid, a `ValidationError`
- is raised which in turn will allow Django to display errors
- as appropriate.
- """
- all_keys = {
- 'title', 'type', 'description', 'url', 'timestamp',
- 'color', 'footer', 'image', 'thumbnail', 'video',
- 'provider', 'author', 'fields'
- }
- one_required_of = {'description', 'fields', 'image', 'title', 'video'}
- field_validators = {
- 'title': (
- MinLengthValidator(
- limit_value=1,
- message="Embed title must not be empty."
- ),
- MaxLengthValidator(limit_value=256)
- ),
- 'description': (MaxLengthValidator(limit_value=4096),),
- 'fields': (
- MaxLengthValidator(limit_value=25),
- validate_embed_fields
- ),
- 'footer': (validate_embed_footer,),
- 'author': (validate_embed_author,)
- }
-
- if not embed:
- raise ValidationError("Tag embed must not be empty.")
-
- elif not isinstance(embed, Mapping):
- raise ValidationError("Tag embed must be a mapping.")
-
- elif not any(field in embed for field in one_required_of):
- raise ValidationError(f"Tag embed must contain one of the fields {one_required_of}.")
-
- for required_key in one_required_of:
- if required_key in embed and not embed[required_key]:
- raise ValidationError(f"Key {required_key!r} must not be empty.")
-
- for field_name, value in embed.items():
- if field_name not in all_keys:
- raise ValidationError(f"Unknown field name: {field_name!r}")
-
- if field_name in field_validators:
- for validator in field_validators[field_name]:
- validator(value)
diff --git a/pydis_site/apps/api/pagination.py b/pydis_site/apps/api/pagination.py
index 2a325460..61707d33 100644
--- a/pydis_site/apps/api/pagination.py
+++ b/pydis_site/apps/api/pagination.py
@@ -1,7 +1,6 @@
-import typing
-
from rest_framework.pagination import LimitOffsetPagination
from rest_framework.response import Response
+from rest_framework.utils.serializer_helpers import ReturnList
class LimitOffsetPaginationExtended(LimitOffsetPagination):
@@ -44,6 +43,6 @@ class LimitOffsetPaginationExtended(LimitOffsetPagination):
default_limit = 100
- def get_paginated_response(self, data: typing.Any) -> Response:
+ def get_paginated_response(self, data: ReturnList) -> Response:
"""Override to skip metadata i.e. `count`, `next`, and `previous`."""
return Response(data)
diff --git a/pydis_site/apps/api/serializers.py b/pydis_site/apps/api/serializers.py
index e53ccffa..9228c1f4 100644
--- a/pydis_site/apps/api/serializers.py
+++ b/pydis_site/apps/api/serializers.py
@@ -176,6 +176,7 @@ class InfractionSerializer(ModelSerializer):
fields = (
'id',
'inserted_at',
+ 'last_applied',
'expires_at',
'active',
'user',
diff --git a/pydis_site/apps/api/tests/migrations/__init__.py b/pydis_site/apps/api/tests/migrations/__init__.py
deleted file mode 100644
index 38e42ffc..00000000
--- a/pydis_site/apps/api/tests/migrations/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""This submodule contains tests for functions used in data migrations."""
diff --git a/pydis_site/apps/api/tests/migrations/base.py b/pydis_site/apps/api/tests/migrations/base.py
deleted file mode 100644
index 0c0a5bd0..00000000
--- a/pydis_site/apps/api/tests/migrations/base.py
+++ /dev/null
@@ -1,102 +0,0 @@
-"""Includes utilities for testing migrations."""
-from django.db import connection
-from django.db.migrations.executor import MigrationExecutor
-from django.test import TestCase
-
-
-class MigrationsTestCase(TestCase):
- """
- A `TestCase` subclass to test migration files.
-
- To be able to properly test a migration, we will need to inject data into the test database
- before the migrations we want to test are applied, but after the older migrations have been
- applied. This makes sure that we are testing "as if" we were actually applying this migration
- to a database in the state it was in before introducing the new migration.
-
- To set up a MigrationsTestCase, create a subclass of this class and set the following
- class-level attributes:
-
- - app: The name of the app that contains the migrations (e.g., `'api'`)
- - migration_prior: The name* of the last migration file before the migrations you want to test
- - migration_target: The name* of the last migration file we want to test
-
- *) Specify the file names without a path or the `.py` file extension.
-
- Additionally, overwrite the `setUpMigrationData` in the subclass to inject data into the
- database before the migrations we want to test are applied. Please read the docstring of the
- method for more information. An optional hook, `setUpPostMigrationData` is also provided.
- """
-
- # These class-level attributes should be set in classes that inherit from this base class.
- app = None
- migration_prior = None
- migration_target = None
-
- @classmethod
- def setUpTestData(cls):
- """
- Injects data into the test database prior to the migration we're trying to test.
-
- This class methods reverts the test database back to the state of the last migration file
- prior to the migrations we want to test. It will then allow the user to inject data into the
- test database by calling the `setUpMigrationData` hook. After the data has been injected, it
- will apply the migrations we want to test and call the `setUpPostMigrationData` hook. The
- user can now test if the migration correctly migrated the injected test data.
- """
- if not cls.app:
- raise ValueError("The `app` attribute was not set.")
-
- if not cls.migration_prior or not cls.migration_target:
- raise ValueError("Both ` migration_prior` and `migration_target` need to be set.")
-
- cls.migrate_from = [(cls.app, cls.migration_prior)]
- cls.migrate_to = [(cls.app, cls.migration_target)]
-
- # Reverse to database state prior to the migrations we want to test
- executor = MigrationExecutor(connection)
- executor.migrate(cls.migrate_from)
-
- # Call the data injection hook with the current state of the project
- old_apps = executor.loader.project_state(cls.migrate_from).apps
- cls.setUpMigrationData(old_apps)
-
- # Run the migrations we want to test
- executor = MigrationExecutor(connection)
- executor.loader.build_graph()
- executor.migrate(cls.migrate_to)
-
- # Save the project state so we're able to work with the correct model states
- cls.apps = executor.loader.project_state(cls.migrate_to).apps
-
- # Call `setUpPostMigrationData` to potentially set up post migration data used in testing
- cls.setUpPostMigrationData(cls.apps)
-
- @classmethod
- def setUpMigrationData(cls, apps):
- """
- Override this method to inject data into the test database before the migration is applied.
-
- This method will be called after setting up the database according to the migrations that
- come before the migration(s) we are trying to test, but before the to-be-tested migration(s)
- are applied. This allows us to simulate a database state just prior to the migrations we are
- trying to test.
-
- To make sure we're creating objects according to the state the models were in at this point
- in the migration history, use `apps.get_model(app_name: str, model_name: str)` to get the
- appropriate model, e.g.:
-
- >>> Infraction = apps.get_model('api', 'Infraction')
- """
- pass
-
- @classmethod
- def setUpPostMigrationData(cls, apps):
- """
- Set up additional test data after the target migration has been applied.
-
- Use `apps.get_model(app_name: str, model_name: str)` to get the correct instances of the
- model classes:
-
- >>> Infraction = apps.get_model('api', 'Infraction')
- """
- pass
diff --git a/pydis_site/apps/api/tests/migrations/test_active_infraction_migration.py b/pydis_site/apps/api/tests/migrations/test_active_infraction_migration.py
deleted file mode 100644
index 8dc29b34..00000000
--- a/pydis_site/apps/api/tests/migrations/test_active_infraction_migration.py
+++ /dev/null
@@ -1,496 +0,0 @@
-"""Tests for the data migration in `filename`."""
-import logging
-from collections import ChainMap, namedtuple
-from datetime import timedelta
-from itertools import count
-from typing import Dict, Iterable, Type, Union
-
-from django.db.models import Q
-from django.forms.models import model_to_dict
-from django.utils import timezone
-
-from pydis_site.apps.api.models import Infraction, User
-from .base import MigrationsTestCase
-
-log = logging.getLogger(__name__)
-log.setLevel(logging.DEBUG)
-
-
-InfractionHistory = namedtuple('InfractionHistory', ("user_id", "infraction_history"))
-
-
-class InfractionFactory:
- """Factory that creates infractions for a User instance."""
-
- infraction_id = count(1)
- user_id = count(1)
- default_values = {
- 'active': True,
- 'expires_at': None,
- 'hidden': False,
- }
-
- @classmethod
- def create(
- cls,
- actor: User,
- infractions: Iterable[Dict[str, Union[str, int, bool]]],
- infraction_model: Type[Infraction] = Infraction,
- user_model: Type[User] = User,
- ) -> InfractionHistory:
- """
- Creates `infractions` for the `user` with the given `actor`.
-
- The `infractions` dictionary can contain the following fields:
- - `type` (required)
- - `active` (default: True)
- - `expires_at` (default: None; i.e, permanent)
- - `hidden` (default: False).
-
- The parameters `infraction_model` and `user_model` can be used to pass in an instance of
- both model classes from a different migration/project state.
- """
- user_id = next(cls.user_id)
- user = user_model.objects.create(
- id=user_id,
- name=f"Infracted user {user_id}",
- discriminator=user_id,
- avatar_hash=None,
- )
- infraction_history = []
-
- for infraction in infractions:
- infraction = dict(infraction)
- infraction["id"] = next(cls.infraction_id)
- infraction = ChainMap(infraction, cls.default_values)
- new_infraction = infraction_model.objects.create(
- user=user,
- actor=actor,
- type=infraction["type"],
- reason=f"`{infraction['type']}` infraction (ID: {infraction['id']} of {user}",
- active=infraction['active'],
- hidden=infraction['hidden'],
- expires_at=infraction['expires_at'],
- )
- infraction_history.append(new_infraction)
-
- return InfractionHistory(user_id=user_id, infraction_history=infraction_history)
-
-
-class InfractionFactoryTests(MigrationsTestCase):
- """Tests for the InfractionFactory."""
-
- app = "api"
- migration_prior = "0046_reminder_jump_url"
- migration_target = "0046_reminder_jump_url"
-
- @classmethod
- def setUpPostMigrationData(cls, apps):
- """Create a default actor for all infractions."""
- cls.infraction_model = apps.get_model('api', 'Infraction')
- cls.user_model = apps.get_model('api', 'User')
-
- cls.actor = cls.user_model.objects.create(
- id=9999,
- name="Unknown Moderator",
- discriminator=1040,
- avatar_hash=None,
- )
-
- def test_infraction_factory_total_count(self):
- """Does the test database hold as many infractions as we tried to create?"""
- InfractionFactory.create(
- actor=self.actor,
- infractions=(
- {'type': 'kick', 'active': False, 'hidden': False},
- {'type': 'ban', 'active': True, 'hidden': False},
- {'type': 'note', 'active': False, 'hidden': True},
- ),
- infraction_model=self.infraction_model,
- user_model=self.user_model,
- )
- database_count = Infraction.objects.all().count()
- self.assertEqual(3, database_count)
-
- def test_infraction_factory_multiple_users(self):
- """Does the test database hold as many infractions as we tried to create?"""
- for _user in range(5):
- InfractionFactory.create(
- actor=self.actor,
- infractions=(
- {'type': 'kick', 'active': False, 'hidden': True},
- {'type': 'ban', 'active': True, 'hidden': False},
- ),
- infraction_model=self.infraction_model,
- user_model=self.user_model,
- )
-
- # Check if infractions and users are recorded properly in the database
- database_count = Infraction.objects.all().count()
- self.assertEqual(database_count, 10)
-
- user_count = User.objects.all().count()
- self.assertEqual(user_count, 5 + 1)
-
- def test_infraction_factory_sets_correct_fields(self):
- """Does the InfractionFactory set the correct attributes?"""
- infractions = (
- {
- 'type': 'note',
- 'active': False,
- 'hidden': True,
- 'expires_at': timezone.now()
- },
- {'type': 'warning', 'active': False, 'hidden': False, 'expires_at': None},
- {'type': 'watch', 'active': False, 'hidden': True, 'expires_at': None},
- {'type': 'mute', 'active': True, 'hidden': False, 'expires_at': None},
- {'type': 'kick', 'active': True, 'hidden': True, 'expires_at': None},
- {'type': 'ban', 'active': True, 'hidden': False, 'expires_at': None},
- {
- 'type': 'superstar',
- 'active': True,
- 'hidden': True,
- 'expires_at': timezone.now()
- },
- )
-
- InfractionFactory.create(
- actor=self.actor,
- infractions=infractions,
- infraction_model=self.infraction_model,
- user_model=self.user_model,
- )
-
- for infraction in infractions:
- with self.subTest(**infraction):
- self.assertTrue(Infraction.objects.filter(**infraction).exists())
-
-
-class ActiveInfractionMigrationTests(MigrationsTestCase):
- """
- Tests the active infraction data migration.
-
- The active infraction data migration should do the following things:
-
- 1. migrates all active notes, warnings, and kicks to an inactive status;
- 2. migrates all users with multiple active infractions of a single type to have only one active
- infraction of that type. The infraction with the longest duration stays active.
- """
-
- app = "api"
- migration_prior = "0046_reminder_jump_url"
- migration_target = "0047_active_infractions_migration"
-
- @classmethod
- def setUpMigrationData(cls, apps):
- """Sets up an initial database state that contains the relevant test cases."""
- # Fetch the Infraction and User model in the current migration state
- cls.infraction_model = apps.get_model('api', 'Infraction')
- cls.user_model = apps.get_model('api', 'User')
-
- cls.created_infractions = {}
-
- # Moderator that serves as actor for all infractions
- cls.user_moderator = cls.user_model.objects.create(
- id=9999,
- name="Olivier de Vienne",
- discriminator=1040,
- avatar_hash=None,
- )
-
- # User #1: clean user with no infractions
- cls.created_infractions["no infractions"] = InfractionFactory.create(
- actor=cls.user_moderator,
- infractions=[],
- infraction_model=cls.infraction_model,
- user_model=cls.user_model,
- )
-
- # User #2: One inactive note infraction
- cls.created_infractions["one inactive note"] = InfractionFactory.create(
- actor=cls.user_moderator,
- infractions=(
- {'type': 'note', 'active': False, 'hidden': True},
- ),
- infraction_model=cls.infraction_model,
- user_model=cls.user_model,
- )
-
- # User #3: One active note infraction
- cls.created_infractions["one active note"] = InfractionFactory.create(
- actor=cls.user_moderator,
- infractions=(
- {'type': 'note', 'active': True, 'hidden': True},
- ),
- infraction_model=cls.infraction_model,
- user_model=cls.user_model,
- )
-
- # User #4: One active and one inactive note infraction
- cls.created_infractions["one active and one inactive note"] = InfractionFactory.create(
- actor=cls.user_moderator,
- infractions=(
- {'type': 'note', 'active': False, 'hidden': True},
- {'type': 'note', 'active': True, 'hidden': True},
- ),
- infraction_model=cls.infraction_model,
- user_model=cls.user_model,
- )
-
- # User #5: Once active note, one active kick, once active warning
- cls.created_infractions["active note, kick, warning"] = InfractionFactory.create(
- actor=cls.user_moderator,
- infractions=(
- {'type': 'note', 'active': True, 'hidden': True},
- {'type': 'kick', 'active': True, 'hidden': True},
- {'type': 'warning', 'active': True, 'hidden': True},
- ),
- infraction_model=cls.infraction_model,
- user_model=cls.user_model,
- )
-
- # User #6: One inactive ban and one active ban
- cls.created_infractions["one inactive and one active ban"] = InfractionFactory.create(
- actor=cls.user_moderator,
- infractions=(
- {'type': 'ban', 'active': False, 'hidden': True},
- {'type': 'ban', 'active': True, 'hidden': True},
- ),
- infraction_model=cls.infraction_model,
- user_model=cls.user_model,
- )
-
- # User #7: Two active permanent bans
- cls.created_infractions["two active perm bans"] = InfractionFactory.create(
- actor=cls.user_moderator,
- infractions=(
- {'type': 'ban', 'active': True, 'hidden': True},
- {'type': 'ban', 'active': True, 'hidden': True},
- ),
- infraction_model=cls.infraction_model,
- user_model=cls.user_model,
- )
-
- # User #8: Multiple active temporary bans
- cls.created_infractions["multiple active temp bans"] = InfractionFactory.create(
- actor=cls.user_moderator,
- infractions=(
- {
- 'type': 'ban',
- 'active': True,
- 'hidden': True,
- 'expires_at': timezone.now() + timedelta(days=1)
- },
- {
- 'type': 'ban',
- 'active': True,
- 'hidden': True,
- 'expires_at': timezone.now() + timedelta(days=10)
- },
- {
- 'type': 'ban',
- 'active': True,
- 'hidden': True,
- 'expires_at': timezone.now() + timedelta(days=20)
- },
- {
- 'type': 'ban',
- 'active': True,
- 'hidden': True,
- 'expires_at': timezone.now() + timedelta(days=5)
- },
- ),
- infraction_model=cls.infraction_model,
- user_model=cls.user_model,
- )
-
- # User #9: One active permanent ban, two active temporary bans
- cls.created_infractions["active perm, two active temp bans"] = InfractionFactory.create(
- actor=cls.user_moderator,
- infractions=(
- {
- 'type': 'ban',
- 'active': True,
- 'hidden': True,
- 'expires_at': timezone.now() + timedelta(days=10)
- },
- {
- 'type': 'ban',
- 'active': True,
- 'hidden': True,
- 'expires_at': None,
- },
- {
- 'type': 'ban',
- 'active': True,
- 'hidden': True,
- 'expires_at': timezone.now() + timedelta(days=7)
- },
- ),
- infraction_model=cls.infraction_model,
- user_model=cls.user_model,
- )
-
- # User #10: One inactive permanent ban, two active temporary bans
- cls.created_infractions["one inactive perm ban, two active temp bans"] = (
- InfractionFactory.create(
- actor=cls.user_moderator,
- infractions=(
- {
- 'type': 'ban',
- 'active': True,
- 'hidden': True,
- 'expires_at': timezone.now() + timedelta(days=10)
- },
- {
- 'type': 'ban',
- 'active': False,
- 'hidden': True,
- 'expires_at': None,
- },
- {
- 'type': 'ban',
- 'active': True,
- 'hidden': True,
- 'expires_at': timezone.now() + timedelta(days=7)
- },
- ),
- infraction_model=cls.infraction_model,
- user_model=cls.user_model,
- )
- )
-
- # User #11: Active ban, active mute, active superstar
- cls.created_infractions["active ban, mute, and superstar"] = InfractionFactory.create(
- actor=cls.user_moderator,
- infractions=(
- {'type': 'ban', 'active': True, 'hidden': True},
- {'type': 'mute', 'active': True, 'hidden': True},
- {'type': 'superstar', 'active': True, 'hidden': True},
- {'type': 'watch', 'active': True, 'hidden': True},
- ),
- infraction_model=cls.infraction_model,
- user_model=cls.user_model,
- )
-
- # User #12: Multiple active bans, active mutes, active superstars
- cls.created_infractions["multiple active bans, mutes, stars"] = InfractionFactory.create(
- actor=cls.user_moderator,
- infractions=(
- {'type': 'ban', 'active': True, 'hidden': True},
- {'type': 'ban', 'active': True, 'hidden': True},
- {'type': 'ban', 'active': True, 'hidden': True},
- {'type': 'mute', 'active': True, 'hidden': True},
- {'type': 'mute', 'active': True, 'hidden': True},
- {'type': 'mute', 'active': True, 'hidden': True},
- {'type': 'superstar', 'active': True, 'hidden': True},
- {'type': 'superstar', 'active': True, 'hidden': True},
- {'type': 'superstar', 'active': True, 'hidden': True},
- {'type': 'watch', 'active': True, 'hidden': True},
- {'type': 'watch', 'active': True, 'hidden': True},
- {'type': 'watch', 'active': True, 'hidden': True},
- ),
- infraction_model=cls.infraction_model,
- user_model=cls.user_model,
- )
-
- def test_all_never_active_types_became_inactive(self):
- """Are all infractions of a non-active type inactive after the migration?"""
- inactive_type_query = Q(type="note") | Q(type="warning") | Q(type="kick")
- self.assertFalse(
- self.infraction_model.objects.filter(inactive_type_query, active=True).exists()
- )
-
- def test_migration_left_clean_user_without_infractions(self):
- """Do users without infractions have no infractions after the migration?"""
- user_id, infraction_history = self.created_infractions["no infractions"]
- self.assertFalse(
- self.infraction_model.objects.filter(user__id=user_id).exists()
- )
-
- def test_migration_left_user_with_inactive_note_untouched(self):
- """Did the migration leave users with only an inactive note untouched?"""
- user_id, infraction_history = self.created_infractions["one inactive note"]
- inactive_note = infraction_history[0]
- self.assertTrue(
- self.infraction_model.objects.filter(**model_to_dict(inactive_note)).exists()
- )
-
- def test_migration_only_touched_active_field_of_active_note(self):
- """Does the migration only change the `active` field?"""
- user_id, infraction_history = self.created_infractions["one active note"]
- note = model_to_dict(infraction_history[0])
- note['active'] = False
- self.assertTrue(
- self.infraction_model.objects.filter(**note).exists()
- )
-
- def test_migration_only_touched_active_field_of_active_note_left_inactive_untouched(self):
- """Does the migration only change the `active` field of active notes?"""
- user_id, infraction_history = self.created_infractions["one active and one inactive note"]
- for note in infraction_history:
- with self.subTest(active=note.active):
- note = model_to_dict(note)
- note['active'] = False
- self.assertTrue(
- self.infraction_model.objects.filter(**note).exists()
- )
-
- def test_migration_migrates_all_nonactive_types_to_inactive(self):
- """Do we set the `active` field of all non-active infractions to `False`?"""
- user_id, infraction_history = self.created_infractions["active note, kick, warning"]
- self.assertFalse(
- self.infraction_model.objects.filter(user__id=user_id, active=True).exists()
- )
-
- def test_migration_leaves_user_with_one_active_ban_untouched(self):
- """Do we leave a user with one active and one inactive ban untouched?"""
- user_id, infraction_history = self.created_infractions["one inactive and one active ban"]
- for infraction in infraction_history:
- with self.subTest(active=infraction.active):
- self.assertTrue(
- self.infraction_model.objects.filter(**model_to_dict(infraction)).exists()
- )
-
- def test_migration_turns_double_active_perm_ban_into_single_active_perm_ban(self):
- """Does the migration turn two active permanent bans into one active permanent ban?"""
- user_id, infraction_history = self.created_infractions["two active perm bans"]
- active_count = self.infraction_model.objects.filter(user__id=user_id, active=True).count()
- self.assertEqual(active_count, 1)
-
- def test_migration_leaves_temporary_ban_with_longest_duration_active(self):
- """Does the migration turn two active permanent bans into one active permanent ban?"""
- user_id, infraction_history = self.created_infractions["multiple active temp bans"]
- active_ban = self.infraction_model.objects.get(user__id=user_id, active=True)
- self.assertEqual(active_ban.expires_at, infraction_history[2].expires_at)
-
- def test_migration_leaves_permanent_ban_active(self):
- """Does the migration leave the permanent ban active?"""
- user_id, infraction_history = self.created_infractions["active perm, two active temp bans"]
- active_ban = self.infraction_model.objects.get(user__id=user_id, active=True)
- self.assertIsNone(active_ban.expires_at)
-
- def test_migration_leaves_longest_temp_ban_active_with_inactive_permanent_ban(self):
- """Does the longest temp ban stay active, even with an inactive perm ban present?"""
- user_id, infraction_history = self.created_infractions[
- "one inactive perm ban, two active temp bans"
- ]
- active_ban = self.infraction_model.objects.get(user__id=user_id, active=True)
- self.assertEqual(active_ban.expires_at, infraction_history[0].expires_at)
-
- def test_migration_leaves_all_active_types_active_if_one_of_each_exists(self):
- """Do all active infractions stay active if only one of each is present?"""
- user_id, infraction_history = self.created_infractions["active ban, mute, and superstar"]
- active_count = self.infraction_model.objects.filter(user__id=user_id, active=True).count()
- self.assertEqual(active_count, 4)
-
- def test_migration_reduces_all_active_types_to_a_single_active_infraction(self):
- """Do we reduce all of the infraction types to one active infraction?"""
- user_id, infraction_history = self.created_infractions["multiple active bans, mutes, stars"]
- active_infractions = self.infraction_model.objects.filter(user__id=user_id, active=True)
- self.assertEqual(len(active_infractions), 4)
- types_observed = [infraction.type for infraction in active_infractions]
-
- for infraction_type in ('ban', 'mute', 'superstar', 'watch'):
- with self.subTest(type=infraction_type):
- self.assertIn(infraction_type, types_observed)
diff --git a/pydis_site/apps/api/tests/migrations/test_base.py b/pydis_site/apps/api/tests/migrations/test_base.py
deleted file mode 100644
index f69bc92c..00000000
--- a/pydis_site/apps/api/tests/migrations/test_base.py
+++ /dev/null
@@ -1,135 +0,0 @@
-import logging
-from unittest.mock import call, patch
-
-from django.db.migrations.loader import MigrationLoader
-from django.test import TestCase
-
-from .base import MigrationsTestCase, connection
-
-log = logging.getLogger(__name__)
-
-
-class SpanishInquisition(MigrationsTestCase):
- app = "api"
- migration_prior = "scragly"
- migration_target = "kosa"
-
-
-@patch("pydis_site.apps.api.tests.migrations.base.MigrationExecutor")
-class MigrationsTestCaseNoSideEffectsTests(TestCase):
- """Tests the MigrationTestCase class with actual migration side effects disabled."""
-
- def setUp(self):
- """Set up an instance of MigrationsTestCase for use in tests."""
- self.test_case = SpanishInquisition()
-
- def test_missing_app_class_raises_value_error(self, _migration_executor):
- """A MigrationsTestCase subclass should set the class-attribute `app`."""
- class Spam(MigrationsTestCase):
- pass
-
- spam = Spam()
- with self.assertRaises(ValueError, msg="The `app` attribute was not set."):
- spam.setUpTestData()
-
- def test_missing_migration_class_attributes_raise_value_error(self, _migration_executor):
- """A MigrationsTestCase subclass should set both `migration_prior` and `migration_target`"""
- class Eggs(MigrationsTestCase):
- app = "api"
- migration_target = "lemon"
-
- class Bacon(MigrationsTestCase):
- app = "api"
- migration_prior = "mark"
-
- instances = (Eggs(), Bacon())
-
- exception_message = "Both ` migration_prior` and `migration_target` need to be set."
- for instance in instances:
- with self.subTest(
- migration_prior=instance.migration_prior,
- migration_target=instance.migration_target,
- ):
- with self.assertRaises(ValueError, msg=exception_message):
- instance.setUpTestData()
-
- @patch(f"{__name__}.SpanishInquisition.setUpMigrationData")
- @patch(f"{__name__}.SpanishInquisition.setUpPostMigrationData")
- def test_migration_data_hooks_are_called_once(self, pre_hook, post_hook, _migration_executor):
- """The `setUpMigrationData` and `setUpPostMigrationData` hooks should be called once."""
- self.test_case.setUpTestData()
- for hook in (pre_hook, post_hook):
- with self.subTest(hook=repr(hook)):
- hook.assert_called_once()
-
- def test_migration_executor_is_instantiated_twice(self, migration_executor):
- """The `MigrationExecutor` should be instantiated with the database connection twice."""
- self.test_case.setUpTestData()
-
- expected_args = [call(connection), call(connection)]
- self.assertEqual(migration_executor.call_args_list, expected_args)
-
- def test_project_state_is_loaded_for_correct_migration_files_twice(self, migration_executor):
- """The `project_state` should first be loaded with `migrate_from`, then `migrate_to`."""
- self.test_case.setUpTestData()
-
- expected_args = [call(self.test_case.migrate_from), call(self.test_case.migrate_to)]
- self.assertEqual(migration_executor().loader.project_state.call_args_list, expected_args)
-
- def test_loader_build_graph_gets_called_once(self, migration_executor):
- """We should rebuild the migration graph before applying the second set of migrations."""
- self.test_case.setUpTestData()
-
- migration_executor().loader.build_graph.assert_called_once()
-
- def test_migration_executor_migrate_method_is_called_correctly_twice(self, migration_executor):
- """The migrate method of the executor should be called twice with the correct arguments."""
- self.test_case.setUpTestData()
-
- self.assertEqual(migration_executor().migrate.call_count, 2)
- calls = [call([('api', 'scragly')]), call([('api', 'kosa')])]
- migration_executor().migrate.assert_has_calls(calls)
-
-
-class LifeOfBrian(MigrationsTestCase):
- app = "api"
- migration_prior = "0046_reminder_jump_url"
- migration_target = "0048_add_infractions_unique_constraints_active"
-
- @classmethod
- def log_last_migration(cls):
- """Parses the applied migrations dictionary to log the last applied migration."""
- loader = MigrationLoader(connection)
- api_migrations = [
- migration for app, migration in loader.applied_migrations if app == cls.app
- ]
- last_migration = max(api_migrations, key=lambda name: int(name[:4]))
- log.info(f"The last applied migration: {last_migration}")
-
- @classmethod
- def setUpMigrationData(cls, apps):
- """Method that logs the last applied migration at this point."""
- cls.log_last_migration()
-
- @classmethod
- def setUpPostMigrationData(cls, apps):
- """Method that logs the last applied migration at this point."""
- cls.log_last_migration()
-
-
-class MigrationsTestCaseMigrationTest(TestCase):
- """Tests if `MigrationsTestCase` travels to the right points in the migration history."""
-
- def test_migrations_test_case_travels_to_correct_migrations_in_history(self):
- """The test case should first revert to `migration_prior`, then go to `migration_target`."""
- brian = LifeOfBrian()
-
- with self.assertLogs(log, level=logging.INFO) as logs:
- brian.setUpTestData()
-
- self.assertEqual(len(logs.records), 2)
-
- for time_point, record in zip(("migration_prior", "migration_target"), logs.records):
- with self.subTest(time_point=time_point):
- message = f"The last applied migration: {getattr(brian, time_point)}"
- self.assertEqual(record.getMessage(), message)
diff --git a/pydis_site/apps/api/tests/test_filterlists.py b/pydis_site/apps/api/tests/test_filterlists.py
index 5a5bca60..9959617e 100644
--- a/pydis_site/apps/api/tests/test_filterlists.py
+++ b/pydis_site/apps/api/tests/test_filterlists.py
@@ -64,8 +64,8 @@ class FetchTests(AuthenticatedAPITestCase):
self.assertEqual(response.status_code, 200)
for api_type, model_type in zip(response.json(), FilterList.FilterListType.choices):
- self.assertEquals(api_type[0], model_type[0])
- self.assertEquals(api_type[1], model_type[1])
+ self.assertEqual(api_type[0], model_type[0])
+ self.assertEqual(api_type[1], model_type[1])
class CreationTests(AuthenticatedAPITestCase):
diff --git a/pydis_site/apps/api/tests/test_github_utils.py b/pydis_site/apps/api/tests/test_github_utils.py
new file mode 100644
index 00000000..95bafec0
--- /dev/null
+++ b/pydis_site/apps/api/tests/test_github_utils.py
@@ -0,0 +1,286 @@
+import dataclasses
+import datetime
+import typing
+import unittest
+from unittest import mock
+
+import django.test
+import httpx
+import jwt
+import rest_framework.response
+import rest_framework.test
+from django.urls import reverse
+
+from pydis_site import settings
+from .. import github_utils
+
+
+class GeneralUtilityTests(unittest.TestCase):
+ """Test the utility methods which do not fit in another class."""
+
+ def test_token_generation(self):
+ """Test that the a valid JWT token is generated."""
+ def encode(payload: dict, _: str, algorithm: str, *args, **kwargs) -> str:
+ """
+ Intercept the encode method.
+
+ The result is encoded with an algorithm which does not require a PEM key, as it may
+ not be available in testing environments.
+ """
+ self.assertEqual("RS256", algorithm, "The GitHub App JWT must be signed using RS256.")
+ return original_encode(
+ payload, "secret-encoding-key", *args, algorithm="HS256", **kwargs
+ )
+
+ original_encode = jwt.encode
+ with mock.patch("jwt.encode", new=encode):
+ token = github_utils.generate_token()
+ decoded = jwt.decode(token, "secret-encoding-key", algorithms=["HS256"])
+
+ delta = datetime.timedelta(minutes=10)
+ self.assertAlmostEqual(decoded["exp"] - decoded["iat"], delta.total_seconds())
+ self.assertLess(decoded["exp"], (datetime.datetime.now() + delta).timestamp())
+
+
+class CheckRunTests(unittest.TestCase):
+ """Tests the check_run_status utility."""
+
+ run_kwargs: typing.Mapping = {
+ "name": "run_name",
+ "head_sha": "sha",
+ "status": "completed",
+ "conclusion": "success",
+ "created_at": datetime.datetime.utcnow().strftime(settings.GITHUB_TIMESTAMP_FORMAT),
+ "artifacts_url": "url",
+ }
+
+ def test_completed_run(self):
+ """Test that an already completed run returns the correct URL."""
+ final_url = "some_url_string_1234"
+
+ kwargs = dict(self.run_kwargs, artifacts_url=final_url)
+ result = github_utils.check_run_status(github_utils.WorkflowRun(**kwargs))
+ self.assertEqual(final_url, result)
+
+ def test_pending_run(self):
+ """Test that a pending run raises the proper exception."""
+ kwargs = dict(self.run_kwargs, status="pending")
+ with self.assertRaises(github_utils.RunPendingError):
+ github_utils.check_run_status(github_utils.WorkflowRun(**kwargs))
+
+ def test_timeout_error(self):
+ """Test that a timeout is declared after a certain duration."""
+ kwargs = dict(self.run_kwargs, status="pending")
+ # Set the creation time to well before the MAX_RUN_TIME
+ # to guarantee the right conclusion
+ kwargs["created_at"] = (
+ datetime.datetime.utcnow() - github_utils.MAX_RUN_TIME - datetime.timedelta(minutes=10)
+ ).strftime(settings.GITHUB_TIMESTAMP_FORMAT)
+
+ with self.assertRaises(github_utils.RunTimeoutError):
+ github_utils.check_run_status(github_utils.WorkflowRun(**kwargs))
+
+ def test_failed_run(self):
+ """Test that a failed run raises the proper exception."""
+ kwargs = dict(self.run_kwargs, conclusion="failed")
+ with self.assertRaises(github_utils.ActionFailedError):
+ github_utils.check_run_status(github_utils.WorkflowRun(**kwargs))
+
+
+def get_response_authorize(_: httpx.Client, request: httpx.Request, **__) -> httpx.Response:
+ """
+ Helper method for the authorize tests.
+
+ Requests are intercepted before being sent out, and the appropriate responses are returned.
+ """
+ path = request.url.path
+ auth = request.headers.get("Authorization")
+
+ if request.method == "GET":
+ if path == "/app/installations":
+ if auth == "bearer JWT initial token":
+ return httpx.Response(200, request=request, json=[{
+ "account": {"login": "VALID_OWNER"},
+ "access_tokens_url": "https://example.com/ACCESS_TOKEN_URL"
+ }])
+ else:
+ return httpx.Response(
+ 401, json={"error": "auth app/installations"}, request=request
+ )
+
+ elif path == "/installation/repositories":
+ if auth == "bearer app access token":
+ return httpx.Response(200, request=request, json={
+ "repositories": [{
+ "name": "VALID_REPO"
+ }]
+ })
+ else: # pragma: no cover
+ return httpx.Response(
+ 401, json={"error": "auth installation/repositories"}, request=request
+ )
+
+ elif request.method == "POST":
+ if path == "/ACCESS_TOKEN_URL":
+ if auth == "bearer JWT initial token":
+ return httpx.Response(200, request=request, json={"token": "app access token"})
+ else: # pragma: no cover
+ return httpx.Response(401, json={"error": "auth access_token"}, request=request)
+
+ # Reaching this point means something has gone wrong
+ return httpx.Response(500, request=request) # pragma: no cover
+
+
[email protected]("httpx.Client.send", new=get_response_authorize)
[email protected](github_utils, "generate_token", new=mock.Mock(return_value="JWT initial token"))
+class AuthorizeTests(unittest.TestCase):
+ """Test the authorize utility."""
+
+ def test_invalid_apps_auth(self):
+ """Test that an exception is raised if authorization was attempted with an invalid token."""
+ with mock.patch.object(github_utils, "generate_token", return_value="Invalid token"):
+ with self.assertRaises(httpx.HTTPStatusError) as error:
+ github_utils.authorize("VALID_OWNER", "VALID_REPO")
+
+ exception: httpx.HTTPStatusError = error.exception
+ self.assertEqual(401, exception.response.status_code)
+ self.assertEqual("auth app/installations", exception.response.json()["error"])
+
+ def test_missing_repo(self):
+ """Test that an exception is raised when the selected owner or repo are not available."""
+ with self.assertRaises(github_utils.NotFoundError):
+ github_utils.authorize("INVALID_OWNER", "VALID_REPO")
+ with self.assertRaises(github_utils.NotFoundError):
+ github_utils.authorize("VALID_OWNER", "INVALID_REPO")
+
+ def test_valid_authorization(self):
+ """Test that an accessible repository can be accessed."""
+ client = github_utils.authorize("VALID_OWNER", "VALID_REPO")
+ self.assertEqual("bearer app access token", client.headers.get("Authorization"))
+
+
+class ArtifactFetcherTests(unittest.TestCase):
+ """Test the get_artifact utility."""
+
+ @staticmethod
+ def get_response_get_artifact(request: httpx.Request, **_) -> httpx.Response:
+ """
+ Helper method for the get_artifact tests.
+
+ Requests are intercepted before being sent out, and the appropriate responses are returned.
+ """
+ path = request.url.path
+
+ if "force_error" in path:
+ return httpx.Response(404, request=request)
+
+ if request.method == "GET":
+ if path == "/repos/owner/repo/actions/runs":
+ run = github_utils.WorkflowRun(
+ name="action_name",
+ head_sha="action_sha",
+ created_at=datetime.datetime.now().strftime(settings.GITHUB_TIMESTAMP_FORMAT),
+ status="completed",
+ conclusion="success",
+ artifacts_url="artifacts_url"
+ )
+ return httpx.Response(
+ 200, request=request, json={"workflow_runs": [dataclasses.asdict(run)]}
+ )
+ elif path == "/artifact_url":
+ return httpx.Response(
+ 200, request=request, json={"artifacts": [{
+ "name": "artifact_name",
+ "archive_download_url": "artifact_download_url"
+ }]}
+ )
+ elif path == "/artifact_download_url":
+ response = httpx.Response(302, request=request)
+ response.next_request = httpx.Request(
+ "GET",
+ httpx.URL("https://final_download.url")
+ )
+ return response
+
+ # Reaching this point means something has gone wrong
+ return httpx.Response(500, request=request) # pragma: no cover
+
+ def setUp(self) -> None:
+ self.call_args = ["owner", "repo", "action_sha", "action_name", "artifact_name"]
+ self.client = httpx.Client(base_url="https://example.com")
+
+ self.patchers = [
+ mock.patch.object(self.client, "send", new=self.get_response_get_artifact),
+ mock.patch.object(github_utils, "authorize", return_value=self.client),
+ mock.patch.object(github_utils, "check_run_status", return_value="artifact_url"),
+ ]
+
+ for patcher in self.patchers:
+ patcher.start()
+
+ def tearDown(self) -> None:
+ for patcher in self.patchers:
+ patcher.stop()
+
+ def test_client_closed_on_errors(self):
+ """Test that the client is terminated even if an error occurs at some point."""
+ self.call_args[0] = "force_error"
+ with self.assertRaises(httpx.HTTPStatusError):
+ github_utils.get_artifact(*self.call_args)
+ self.assertTrue(self.client.is_closed)
+
+ def test_missing(self):
+ """Test that an exception is raised if the requested artifact was not found."""
+ cases = (
+ "invalid sha",
+ "invalid action name",
+ "invalid artifact name",
+ )
+ for i, name in enumerate(cases, 2):
+ with self.subTest(f"Test {name} raises an error"):
+ new_args = self.call_args.copy()
+ new_args[i] = name
+
+ with self.assertRaises(github_utils.NotFoundError):
+ github_utils.get_artifact(*new_args)
+
+ def test_valid(self):
+ """Test that the correct download URL is returned for valid requests."""
+ url = github_utils.get_artifact(*self.call_args)
+ self.assertEqual("https://final_download.url", url)
+ self.assertTrue(self.client.is_closed)
+
+
[email protected](github_utils, "get_artifact")
+class GitHubArtifactViewTests(django.test.TestCase):
+ """Test the GitHub artifact fetch API view."""
+
+ def setUp(self):
+ self.kwargs = {
+ "owner": "test_owner",
+ "repo": "test_repo",
+ "sha": "test_sha",
+ "action_name": "test_action",
+ "artifact_name": "test_artifact",
+ }
+ self.url = reverse("api:github-artifacts", kwargs=self.kwargs)
+
+ def test_correct_artifact(self, artifact_mock: mock.Mock):
+ """Test a proper response is returned with proper input."""
+ artifact_mock.return_value = "final download url"
+ result = self.client.get(self.url)
+
+ self.assertIsInstance(result, rest_framework.response.Response)
+ self.assertEqual({"url": artifact_mock.return_value}, result.data)
+
+ def test_failed_fetch(self, artifact_mock: mock.Mock):
+ """Test that a proper error is returned when the request fails."""
+ artifact_mock.side_effect = github_utils.NotFoundError("Test error message")
+ result = self.client.get(self.url)
+
+ self.assertIsInstance(result, rest_framework.response.Response)
+ self.assertEqual({
+ "error_type": github_utils.NotFoundError.__name__,
+ "error": "Test error message",
+ "requested_resource": "/".join(self.kwargs.values())
+ }, result.data)
diff --git a/pydis_site/apps/api/tests/test_infractions.py b/pydis_site/apps/api/tests/test_infractions.py
index f1107734..89ee4e23 100644
--- a/pydis_site/apps/api/tests/test_infractions.py
+++ b/pydis_site/apps/api/tests/test_infractions.py
@@ -56,15 +56,17 @@ class InfractionTests(AuthenticatedAPITestCase):
type='ban',
reason='He terk my jerb!',
hidden=True,
+ inserted_at=dt(2020, 10, 10, 0, 0, 0, tzinfo=timezone.utc),
expires_at=dt(5018, 11, 20, 15, 52, tzinfo=timezone.utc),
- active=True
+ active=True,
)
cls.ban_inactive = Infraction.objects.create(
user_id=cls.user.id,
actor_id=cls.user.id,
type='ban',
reason='James is an ass, and we won\'t be working with him again.',
- active=False
+ active=False,
+ inserted_at=dt(2020, 10, 10, 0, 1, 0, tzinfo=timezone.utc),
)
cls.mute_permanent = Infraction.objects.create(
user_id=cls.user.id,
@@ -72,7 +74,8 @@ class InfractionTests(AuthenticatedAPITestCase):
type='mute',
reason='He has a filthy mouth and I am his soap.',
active=True,
- expires_at=None
+ inserted_at=dt(2020, 10, 10, 0, 2, 0, tzinfo=timezone.utc),
+ expires_at=None,
)
cls.superstar_expires_soon = Infraction.objects.create(
user_id=cls.user.id,
@@ -80,7 +83,8 @@ class InfractionTests(AuthenticatedAPITestCase):
type='superstar',
reason='This one doesn\'t matter anymore.',
active=True,
- expires_at=dt.now(timezone.utc) + datetime.timedelta(hours=5)
+ inserted_at=dt(2020, 10, 10, 0, 3, 0, tzinfo=timezone.utc),
+ expires_at=dt.now(timezone.utc) + datetime.timedelta(hours=5),
)
cls.voiceban_expires_later = Infraction.objects.create(
user_id=cls.user.id,
@@ -88,7 +92,8 @@ class InfractionTests(AuthenticatedAPITestCase):
type='voice_ban',
reason='Jet engine mic',
active=True,
- expires_at=dt.now(timezone.utc) + datetime.timedelta(days=5)
+ inserted_at=dt(2020, 10, 10, 0, 4, 0, tzinfo=timezone.utc),
+ expires_at=dt.now(timezone.utc) + datetime.timedelta(days=5),
)
def test_list_all(self):
diff --git a/pydis_site/apps/api/tests/test_models.py b/pydis_site/apps/api/tests/test_models.py
index 0fad467c..c07d59cd 100644
--- a/pydis_site/apps/api/tests/test_models.py
+++ b/pydis_site/apps/api/tests/test_models.py
@@ -7,7 +7,6 @@ from pydis_site.apps.api.models import (
DeletedMessage,
DocumentationLink,
Infraction,
- Message,
MessageDeletionContext,
Nomination,
NominationEntry,
@@ -116,17 +115,6 @@ class StringDunderMethodTests(SimpleTestCase):
colour=0x5, permissions=0,
position=10,
),
- Message(
- id=45,
- author=User(
- id=444,
- name='bill',
- discriminator=5,
- ),
- channel_id=666,
- content="wooey",
- embeds=[]
- ),
MessageDeletionContext(
actor=User(
id=5555,
diff --git a/pydis_site/apps/api/tests/test_users.py b/pydis_site/apps/api/tests/test_users.py
index 5d10069d..d86e80bb 100644
--- a/pydis_site/apps/api/tests/test_users.py
+++ b/pydis_site/apps/api/tests/test_users.py
@@ -502,6 +502,90 @@ class UserMetricityTests(AuthenticatedAPITestCase):
"total_messages": total_messages
})
+ def test_metricity_activity_data(self):
+ # Given
+ self.mock_no_metricity_user() # Other functions shouldn't be used.
+ self.metricity.total_messages_in_past_n_days.return_value = [(0, 10)]
+
+ # When
+ url = reverse("api:bot:user-metricity-activity-data")
+ response = self.client.post(
+ url,
+ data=[0, 1],
+ QUERY_STRING="days=10",
+ )
+
+ # Then
+ self.assertEqual(response.status_code, 200)
+ self.metricity.total_messages_in_past_n_days.assert_called_once_with(["0", "1"], 10)
+ self.assertEqual(response.json(), {"0": 10, "1": 0})
+
+ def test_metricity_activity_data_invalid_days(self):
+ # Given
+ self.mock_no_metricity_user() # Other functions shouldn't be used.
+
+ # When
+ url = reverse("api:bot:user-metricity-activity-data")
+ response = self.client.post(
+ url,
+ data=[0, 1],
+ QUERY_STRING="days=fifty",
+ )
+
+ # Then
+ self.assertEqual(response.status_code, 400)
+ self.metricity.total_messages_in_past_n_days.assert_not_called()
+ self.assertEqual(response.json(), {"days": ["This query parameter must be an integer."]})
+
+ def test_metricity_activity_data_no_days(self):
+ # Given
+ self.mock_no_metricity_user() # Other functions shouldn't be used.
+
+ # When
+ url = reverse('api:bot:user-metricity-activity-data')
+ response = self.client.post(
+ url,
+ data=[0, 1],
+ )
+
+ # Then
+ self.assertEqual(response.status_code, 400)
+ self.metricity.total_messages_in_past_n_days.assert_not_called()
+ self.assertEqual(response.json(), {'days': ["This query parameter is required."]})
+
+ def test_metricity_activity_data_no_users(self):
+ # Given
+ self.mock_no_metricity_user() # Other functions shouldn't be used.
+
+ # When
+ url = reverse('api:bot:user-metricity-activity-data')
+ response = self.client.post(
+ url,
+ QUERY_STRING="days=10",
+ )
+
+ # Then
+ self.assertEqual(response.status_code, 400)
+ self.metricity.total_messages_in_past_n_days.assert_not_called()
+ self.assertEqual(response.json(), ['Expected a list of items but got type "dict".'])
+
+ def test_metricity_activity_data_invalid_users(self):
+ # Given
+ self.mock_no_metricity_user() # Other functions shouldn't be used.
+
+ # When
+ url = reverse('api:bot:user-metricity-activity-data')
+ response = self.client.post(
+ url,
+ data=[123, 'username'],
+ QUERY_STRING="days=10",
+ )
+
+ # Then
+ self.assertEqual(response.status_code, 400)
+ self.metricity.total_messages_in_past_n_days.assert_not_called()
+ self.assertEqual(response.json(), {'1': ['A valid integer is required.']})
+
def mock_metricity_user(self, joined_at, total_messages, total_blocks, top_channel_activity):
patcher = patch("pydis_site.apps.api.viewsets.bot.user.Metricity")
self.metricity = patcher.start()
diff --git a/pydis_site/apps/api/tests/test_validators.py b/pydis_site/apps/api/tests/test_validators.py
index 551cc2aa..8c46fcbc 100644
--- a/pydis_site/apps/api/tests/test_validators.py
+++ b/pydis_site/apps/api/tests/test_validators.py
@@ -5,7 +5,6 @@ from django.test import TestCase
from ..models.bot.bot_setting import validate_bot_setting_name
from ..models.bot.offensive_message import future_date_validator
-from ..models.utils import validate_embed
REQUIRED_KEYS = (
@@ -22,234 +21,6 @@ class BotSettingValidatorTests(TestCase):
validate_bot_setting_name('bad name')
-class TagEmbedValidatorTests(TestCase):
- def test_rejects_non_mapping(self):
- with self.assertRaises(ValidationError):
- validate_embed('non-empty non-mapping')
-
- def test_rejects_missing_required_keys(self):
- with self.assertRaises(ValidationError):
- validate_embed({
- 'unknown': "key"
- })
-
- def test_rejects_one_correct_one_incorrect(self):
- with self.assertRaises(ValidationError):
- validate_embed({
- 'provider': "??",
- 'title': ""
- })
-
- def test_rejects_empty_required_key(self):
- with self.assertRaises(ValidationError):
- validate_embed({
- 'title': ''
- })
-
- def test_rejects_list_as_embed(self):
- with self.assertRaises(ValidationError):
- validate_embed([])
-
- def test_rejects_required_keys_and_unknown_keys(self):
- with self.assertRaises(ValidationError):
- validate_embed({
- 'title': "the duck walked up to the lemonade stand",
- 'and': "he said to the man running the stand"
- })
-
- def test_rejects_too_long_title(self):
- with self.assertRaises(ValidationError):
- validate_embed({
- 'title': 'a' * 257
- })
-
- def test_rejects_too_many_fields(self):
- with self.assertRaises(ValidationError):
- validate_embed({
- 'fields': [{} for _ in range(26)]
- })
-
- def test_rejects_too_long_description(self):
- with self.assertRaises(ValidationError):
- validate_embed({
- 'description': 'd' * 4097
- })
-
- def test_allows_valid_embed(self):
- validate_embed({
- 'title': "My embed",
- 'description': "look at my embed, my embed is amazing"
- })
-
- def test_allows_unvalidated_fields(self):
- validate_embed({
- 'title': "My embed",
- 'provider': "what am I??"
- })
-
- def test_rejects_fields_as_list_of_non_mappings(self):
- with self.assertRaises(ValidationError):
- validate_embed({
- 'fields': ['abc']
- })
-
- def test_rejects_fields_with_unknown_fields(self):
- with self.assertRaises(ValidationError):
- validate_embed({
- 'fields': [
- {
- 'what': "is this field"
- }
- ]
- })
-
- def test_rejects_fields_with_too_long_name(self):
- with self.assertRaises(ValidationError):
- validate_embed({
- 'fields': [
- {
- 'name': "a" * 257
- }
- ]
- })
-
- def test_rejects_one_correct_one_incorrect_field(self):
- with self.assertRaises(ValidationError):
- validate_embed({
- 'fields': [
- {
- 'name': "Totally valid",
- 'value': "LOOK AT ME"
- },
- {
- 'name': "Totally valid",
- 'value': "LOOK AT ME",
- 'oh': "what is this key?"
- }
- ]
- })
-
- def test_rejects_missing_required_field_field(self):
- with self.assertRaises(ValidationError):
- validate_embed({
- 'fields': [
- {
- 'name': "Totally valid",
- 'inline': True,
- }
- ]
- })
-
- def test_rejects_invalid_inline_field_field(self):
- with self.assertRaises(ValidationError):
- validate_embed({
- 'fields': [
- {
- 'name': "Totally valid",
- 'value': "LOOK AT ME",
- 'inline': "Totally not a boolean",
- }
- ]
- })
-
- def test_allows_valid_fields(self):
- validate_embed({
- 'fields': [
- {
- 'name': "valid",
- 'value': "field",
- },
- {
- 'name': "valid",
- 'value': "field",
- 'inline': False,
- },
- {
- 'name': "valid",
- 'value': "field",
- 'inline': True,
- },
- ]
- })
-
- def test_rejects_footer_as_non_mapping(self):
- with self.assertRaises(ValidationError):
- validate_embed({
- 'title': "whatever",
- 'footer': []
- })
-
- def test_rejects_footer_with_unknown_fields(self):
- with self.assertRaises(ValidationError):
- validate_embed({
- 'title': "whatever",
- 'footer': {
- 'duck': "quack"
- }
- })
-
- def test_rejects_footer_with_empty_text(self):
- with self.assertRaises(ValidationError):
- validate_embed({
- 'title': "whatever",
- 'footer': {
- 'text': ""
- }
- })
-
- def test_allows_footer_with_proper_values(self):
- validate_embed({
- 'title': "whatever",
- 'footer': {
- 'text': "django good"
- }
- })
-
- def test_rejects_author_as_non_mapping(self):
- with self.assertRaises(ValidationError):
- validate_embed({
- 'title': "whatever",
- 'author': []
- })
-
- def test_rejects_author_with_unknown_field(self):
- with self.assertRaises(ValidationError):
- validate_embed({
- 'title': "whatever",
- 'author': {
- 'field': "that is unknown"
- }
- })
-
- def test_rejects_author_with_empty_name(self):
- with self.assertRaises(ValidationError):
- validate_embed({
- 'title': "whatever",
- 'author': {
- 'name': ""
- }
- })
-
- def test_rejects_author_with_one_correct_one_incorrect(self):
- with self.assertRaises(ValidationError):
- validate_embed({
- 'title': "whatever",
- 'author': {
- # Relies on "dictionary insertion order remembering" (D.I.O.R.) behaviour
- 'url': "bobswebsite.com",
- 'name': ""
- }
- })
-
- def test_allows_author_with_proper_values(self):
- validate_embed({
- 'title': "whatever",
- 'author': {
- 'name': "Bob"
- }
- })
-
-
class OffensiveMessageValidatorsTests(TestCase):
def test_accepts_future_date(self):
future_date_validator(datetime(3000, 1, 1, tzinfo=timezone.utc))
diff --git a/pydis_site/apps/api/urls.py b/pydis_site/apps/api/urls.py
index 1e564b29..2757f176 100644
--- a/pydis_site/apps/api/urls.py
+++ b/pydis_site/apps/api/urls.py
@@ -1,7 +1,7 @@
from django.urls import include, path
from rest_framework.routers import DefaultRouter
-from .views import HealthcheckView, RulesView
+from .views import GitHubArtifactsView, HealthcheckView, RulesView
from .viewsets import (
AocAccountLinkViewSet,
AocCompletionistBlockViewSet,
@@ -86,5 +86,10 @@ urlpatterns = (
# from django_hosts.resolvers import reverse
path('bot/', include((bot_router.urls, 'api'), namespace='bot')),
path('healthcheck', HealthcheckView.as_view(), name='healthcheck'),
- path('rules', RulesView.as_view(), name='rules')
+ path('rules', RulesView.as_view(), name='rules'),
+ path(
+ 'github/artifact/<str:owner>/<str:repo>/<str:sha>/<str:action_name>/<str:artifact_name>',
+ GitHubArtifactsView.as_view(),
+ name="github-artifacts"
+ ),
)
diff --git a/pydis_site/apps/api/views.py b/pydis_site/apps/api/views.py
index 816463f6..34167a38 100644
--- a/pydis_site/apps/api/views.py
+++ b/pydis_site/apps/api/views.py
@@ -1,7 +1,10 @@
from rest_framework.exceptions import ParseError
+from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView
+from . import github_utils
+
class HealthcheckView(APIView):
"""
@@ -34,12 +37,14 @@ class RulesView(APIView):
## Routes
### GET /rules
- Returns a JSON array containing the server's rules:
+ Returns a JSON array containing the server's rules
+ and keywords relating to each rule.
+ Example response:
>>> [
- ... "Eat candy.",
- ... "Wake up at 4 AM.",
- ... "Take your medicine."
+ ... ["Eat candy.", ["candy", "sweets"]],
+ ... ["Wake up at 4 AM.", ["wake_up", "early", "early_bird"]],
+ ... ["Take your medicine.", ["medicine", "health"]]
... ]
Since some of the the rules require links, this view
@@ -97,6 +102,12 @@ class RulesView(APIView):
# `format` here is the result format, we have a link format here instead.
def get(self, request, format=None): # noqa: D102,ANN001,ANN201
+ """
+ Returns a list of our community rules coupled with their keywords.
+
+ Each item in the returned list is a tuple with the rule as first item
+ and a list of keywords that match that rules as second item.
+ """
link_format = request.query_params.get('link_format', 'md')
if link_format not in ('html', 'md'):
raise ParseError(
@@ -121,34 +132,93 @@ class RulesView(APIView):
return Response([
(
- f"Follow the {pydis_coc}."
+ f"Follow the {pydis_coc}.",
+ ["coc", "conduct", "code"]
),
(
- f"Follow the {discord_community_guidelines} and {discord_tos}."
+ f"Follow the {discord_community_guidelines} and {discord_tos}.",
+ ["discord", "guidelines", "discord_tos"]
),
(
- "Respect staff members and listen to their instructions."
+ "Respect staff members and listen to their instructions.",
+ ["respect", "staff", "instructions"]
),
(
"Use English to the best of your ability. "
- "Be polite if someone speaks English imperfectly."
+ "Be polite if someone speaks English imperfectly.",
+ ["english", "language"]
),
(
"Do not provide or request help on projects that may break laws, "
- "breach terms of services, or are malicious or inappropriate."
+ "breach terms of services, or are malicious or inappropriate.",
+ ["infraction", "tos", "breach", "malicious", "inappropriate"]
),
(
- "Do not post unapproved advertising."
+ "Do not post unapproved advertising.",
+ ["ad", "ads", "advert", "advertising"]
),
(
"Keep discussions relevant to the channel topic. "
- "Each channel's description tells you the topic."
+ "Each channel's description tells you the topic.",
+ ["off-topic", "topic", "relevance"]
),
(
"Do not help with ongoing exams. When helping with homework, "
- "help people learn how to do the assignment without doing it for them."
+ "help people learn how to do the assignment without doing it for them.",
+ ["exam", "exams", "assignment", "assignments", "homework"]
),
(
- "Do not offer or ask for paid work of any kind."
+ "Do not offer or ask for paid work of any kind.",
+ ["paid", "work", "money"]
),
])
+
+
+class GitHubArtifactsView(APIView):
+ """
+ Provides utilities for interacting with the GitHub API and obtaining action artifacts.
+
+ ## Routes
+ ### GET /github/artifacts
+ Returns a download URL for the artifact requested.
+
+ {
+ 'url': 'https://pipelines.actions.githubusercontent.com/...'
+ }
+
+ ### Exceptions
+ In case of an error, the following body will be returned:
+
+ {
+ "error_type": "<error class name>",
+ "error": "<error description>",
+ "requested_resource": "<owner>/<repo>/<sha>/<artifact_name>"
+ }
+
+ ## Authentication
+ Does not require any authentication nor permissions.
+ """
+
+ authentication_classes = ()
+ permission_classes = ()
+
+ def get(
+ self,
+ request: Request,
+ *,
+ owner: str,
+ repo: str,
+ sha: str,
+ action_name: str,
+ artifact_name: str
+ ) -> Response:
+ """Return a download URL for the requested artifact."""
+ try:
+ url = github_utils.get_artifact(owner, repo, sha, action_name, artifact_name)
+ return Response({"url": url})
+ except github_utils.ArtifactProcessingError as e:
+ return Response({
+ "error_type": e.__class__.__name__,
+ "error": str(e),
+ "requested_resource": f"{owner}/{repo}/{sha}/{action_name}/{artifact_name}"
+ }, status=e.status)
diff --git a/pydis_site/apps/api/viewsets/bot/aoc_completionist_block.py b/pydis_site/apps/api/viewsets/bot/aoc_completionist_block.py
index 3a4cec60..97efb63c 100644
--- a/pydis_site/apps/api/viewsets/bot/aoc_completionist_block.py
+++ b/pydis_site/apps/api/viewsets/bot/aoc_completionist_block.py
@@ -70,4 +70,4 @@ class AocCompletionistBlockViewSet(
serializer_class = AocCompletionistBlockSerializer
queryset = AocCompletionistBlock.objects.all()
filter_backends = (DjangoFilterBackend,)
- filter_fields = ("user__id", "is_blocked")
+ filterset_fields = ("user__id", "is_blocked")
diff --git a/pydis_site/apps/api/viewsets/bot/aoc_link.py b/pydis_site/apps/api/viewsets/bot/aoc_link.py
index c7a96629..3cdc342d 100644
--- a/pydis_site/apps/api/viewsets/bot/aoc_link.py
+++ b/pydis_site/apps/api/viewsets/bot/aoc_link.py
@@ -68,4 +68,4 @@ class AocAccountLinkViewSet(
serializer_class = AocAccountLinkSerializer
queryset = AocAccountLink.objects.all()
filter_backends = (DjangoFilterBackend,)
- filter_fields = ("user__id", "aoc_username")
+ filterset_fields = ("user__id", "aoc_username")
diff --git a/pydis_site/apps/api/viewsets/bot/infraction.py b/pydis_site/apps/api/viewsets/bot/infraction.py
index 7f31292f..93d29391 100644
--- a/pydis_site/apps/api/viewsets/bot/infraction.py
+++ b/pydis_site/apps/api/viewsets/bot/infraction.py
@@ -1,9 +1,8 @@
-from datetime import datetime
+import datetime
from django.db import IntegrityError
from django.db.models import QuerySet
from django.http.request import HttpRequest
-from django.utils import timezone
from django_filters.rest_framework import DjangoFilterBackend
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
@@ -154,7 +153,7 @@ class InfractionViewSet(
queryset = Infraction.objects.all()
pagination_class = LimitOffsetPaginationExtended
filter_backends = (DjangoFilterBackend, SearchFilter, OrderingFilter)
- filter_fields = ('user__id', 'actor__id', 'active', 'hidden', 'type')
+ filterset_fields = ('user__id', 'actor__id', 'active', 'hidden', 'type')
search_fields = ('$reason',)
frozen_fields = ('id', 'inserted_at', 'type', 'user', 'actor', 'hidden')
@@ -185,23 +184,21 @@ class InfractionViewSet(
filter_expires_after = self.request.query_params.get('expires_after')
if filter_expires_after:
try:
- expires_after_parsed = datetime.fromisoformat(filter_expires_after)
+ expires_after_parsed = datetime.datetime.fromisoformat(filter_expires_after)
except ValueError:
raise ValidationError({'expires_after': ['failed to convert to datetime']})
- additional_filters['expires_at__gte'] = timezone.make_aware(
- expires_after_parsed,
- timezone=timezone.utc,
+ additional_filters['expires_at__gte'] = expires_after_parsed.replace(
+ tzinfo=datetime.timezone.utc
)
filter_expires_before = self.request.query_params.get('expires_before')
if filter_expires_before:
try:
- expires_before_parsed = datetime.fromisoformat(filter_expires_before)
+ expires_before_parsed = datetime.datetime.fromisoformat(filter_expires_before)
except ValueError:
raise ValidationError({'expires_before': ['failed to convert to datetime']})
- additional_filters['expires_at__lte'] = timezone.make_aware(
- expires_before_parsed,
- timezone=timezone.utc,
+ additional_filters['expires_at__lte'] = expires_before_parsed.replace(
+ tzinfo=datetime.timezone.utc
)
if 'expires_at__lte' in additional_filters and 'expires_at__gte' in additional_filters:
diff --git a/pydis_site/apps/api/viewsets/bot/nomination.py b/pydis_site/apps/api/viewsets/bot/nomination.py
index 144daab0..6af42bcb 100644
--- a/pydis_site/apps/api/viewsets/bot/nomination.py
+++ b/pydis_site/apps/api/viewsets/bot/nomination.py
@@ -172,7 +172,7 @@ class NominationViewSet(CreateModelMixin, RetrieveModelMixin, ListModelMixin, Ge
serializer_class = NominationSerializer
queryset = Nomination.objects.all()
filter_backends = (DjangoFilterBackend, SearchFilter, OrderingFilter)
- filter_fields = ('user__id', 'active')
+ filterset_fields = ('user__id', 'active')
frozen_fields = ('id', 'inserted_at', 'user', 'ended_at')
frozen_on_create = ('ended_at', 'end_reason', 'active', 'inserted_at', 'reviewed')
diff --git a/pydis_site/apps/api/viewsets/bot/reminder.py b/pydis_site/apps/api/viewsets/bot/reminder.py
index 78d7cb3b..5f997052 100644
--- a/pydis_site/apps/api/viewsets/bot/reminder.py
+++ b/pydis_site/apps/api/viewsets/bot/reminder.py
@@ -125,4 +125,4 @@ class ReminderViewSet(
serializer_class = ReminderSerializer
queryset = Reminder.objects.prefetch_related('author')
filter_backends = (DjangoFilterBackend, SearchFilter)
- filter_fields = ('active', 'author__id')
+ filterset_fields = ('active', 'author__id')
diff --git a/pydis_site/apps/api/viewsets/bot/user.py b/pydis_site/apps/api/viewsets/bot/user.py
index 3318b2b9..db73a83c 100644
--- a/pydis_site/apps/api/viewsets/bot/user.py
+++ b/pydis_site/apps/api/viewsets/bot/user.py
@@ -3,8 +3,9 @@ from collections import OrderedDict
from django.db.models import Q
from django_filters.rest_framework import DjangoFilterBackend
-from rest_framework import status
+from rest_framework import fields, status
from rest_framework.decorators import action
+from rest_framework.exceptions import ParseError
from rest_framework.pagination import PageNumberPagination
from rest_framework.request import Request
from rest_framework.response import Response
@@ -138,6 +139,29 @@ class UserViewSet(ModelViewSet):
- 200: returned on success
- 404: if a user with the given `snowflake` could not be found
+ ### POST /bot/users/metricity_activity_data
+ Returns a mapping of user ID to message count in a given period for
+ the given user IDs.
+
+ #### Required Query Parameters
+ - days: how many days into the past to count message from.
+
+ #### Request Format
+ >>> [
+ ... 409107086526644234,
+ ... 493839819168808962
+ ... ]
+
+ #### Response format
+ >>> {
+ ... "409107086526644234": 54,
+ ... "493839819168808962": 0
+ ... }
+
+ #### Status codes
+ - 200: returned on success
+ - 400: if request body or query parameters were missing or invalid
+
### POST /bot/users
Adds a single or multiple new users.
The roles attached to the user(s) must be roles known by the site.
@@ -237,7 +261,7 @@ class UserViewSet(ModelViewSet):
queryset = User.objects.all().order_by("id")
pagination_class = UserListPagination
filter_backends = (DjangoFilterBackend,)
- filter_fields = ('name', 'discriminator')
+ filterset_fields = ('name', 'discriminator')
def get_serializer(self, *args, **kwargs) -> ModelSerializer:
"""Set Serializer many attribute to True if request body contains a list."""
@@ -298,3 +322,34 @@ class UserViewSet(ModelViewSet):
except NotFoundError:
return Response(dict(detail="User not found in metricity"),
status=status.HTTP_404_NOT_FOUND)
+
+ @action(detail=False, methods=["POST"])
+ def metricity_activity_data(self, request: Request) -> Response:
+ """Request handler for metricity_activity_data endpoint."""
+ if "days" in request.query_params:
+ try:
+ days = int(request.query_params["days"])
+ except ValueError:
+ raise ParseError(detail={
+ "days": ["This query parameter must be an integer."]
+ })
+ else:
+ raise ParseError(detail={
+ "days": ["This query parameter is required."]
+ })
+
+ user_id_list_validator = fields.ListField(
+ child=fields.IntegerField(min_value=0),
+ allow_empty=False
+ )
+ user_ids = [
+ str(user_id) for user_id in
+ user_id_list_validator.run_validation(request.data)
+ ]
+
+ with Metricity() as metricity:
+ data = metricity.total_messages_in_past_n_days(user_ids, days)
+
+ default_data = {user_id: 0 for user_id in user_ids}
+ response_data = default_data | dict(data)
+ return Response(response_data, status=status.HTTP_200_OK)