aboutsummaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorGravatar Boris Muratov <[email protected]>2024-03-26 16:20:12 +0200
committerGravatar GitHub <[email protected]>2024-03-26 16:20:12 +0200
commitb7980e30f380932155450893bd87705aea8eb848 (patch)
treefbd85c165b2e4098453c051a3cb20aa7df448ac9 /tests
parentHandle errors when adding invite (diff)
parentAsk for confirmation when banning members with elevated roles (#2316) (diff)
Merge branch 'main' into phishing_button
Diffstat (limited to 'tests')
-rw-r--r--tests/bot/.testenv2
-rw-r--r--tests/bot/exts/backend/test_error_handler.py7
-rw-r--r--tests/bot/exts/utils/snekbox/test_snekbox.py19
-rw-r--r--tests/bot/test_constants.py64
4 files changed, 45 insertions, 47 deletions
diff --git a/tests/bot/.testenv b/tests/bot/.testenv
new file mode 100644
index 000000000..484c8809d
--- /dev/null
+++ b/tests/bot/.testenv
@@ -0,0 +1,2 @@
+unittests_goat=volcyy
+unittests_nested__server_name=pydis
diff --git a/tests/bot/exts/backend/test_error_handler.py b/tests/bot/exts/backend/test_error_handler.py
index 9670d42a0..dbc62270b 100644
--- a/tests/bot/exts/backend/test_error_handler.py
+++ b/tests/bot/exts/backend/test_error_handler.py
@@ -414,12 +414,13 @@ class IndividualErrorHandlerTests(unittest.IsolatedAsyncioTestCase):
for case in test_cases:
with self.subTest(error=case["error"], call_prepared=case["call_prepared"]):
self.ctx.reset_mock()
+ self.cog.send_error_with_help = AsyncMock()
self.assertIsNone(await self.cog.handle_user_input_error(self.ctx, case["error"]))
- self.ctx.send.assert_awaited_once()
if case["call_prepared"]:
- self.ctx.send_help.assert_awaited_once()
+ self.cog.send_error_with_help.assert_awaited_once()
else:
- self.ctx.send_help.assert_not_awaited()
+ self.ctx.send.assert_awaited_once()
+ self.cog.send_error_with_help.assert_not_awaited()
async def test_handle_check_failure_errors(self):
"""Should await `ctx.send` when error is check failure."""
diff --git a/tests/bot/exts/utils/snekbox/test_snekbox.py b/tests/bot/exts/utils/snekbox/test_snekbox.py
index 8ee0f46ff..d057b284d 100644
--- a/tests/bot/exts/utils/snekbox/test_snekbox.py
+++ b/tests/bot/exts/utils/snekbox/test_snekbox.py
@@ -113,7 +113,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
result = EvalResult(stdout=stdout, returncode=returncode)
job = EvalJob([])
# Check all 3 message types
- msg = result.get_message(job)
+ msg = result.get_status_message(job)
self.assertEqual(msg, exp_msg)
error = result.error_message
self.assertEqual(error, exp_err)
@@ -166,7 +166,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
def test_eval_result_message_invalid_signal(self, _mock_signals: Mock):
result = EvalResult(stdout="", returncode=127)
self.assertEqual(
- result.get_message(EvalJob([], version="3.10")),
+ result.get_status_message(EvalJob([], version="3.10")),
"Your 3.10 eval job has completed with return code 127"
)
self.assertEqual(result.error_message, "")
@@ -177,7 +177,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
mock_signals.return_value.name = "SIGTEST"
result = EvalResult(stdout="", returncode=127)
self.assertEqual(
- result.get_message(EvalJob([], version="3.12")),
+ result.get_status_message(EvalJob([], version="3.12")),
"Your 3.12 eval job has completed with return code 127 (SIGTEST)"
)
@@ -386,12 +386,19 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
ctx.send = AsyncMock()
ctx.author.mention = "@user#7700"
- eval_result = EvalResult("", 0, files=[FileAttachment("test.disallowed", b"test")])
+ files = [
+ FileAttachment("test.disallowed2", b"test"),
+ FileAttachment("test.disallowed", b"test"),
+ FileAttachment("test.allowed", b"test"),
+ FileAttachment("test." + ("a" * 100), b"test")
+ ]
+ eval_result = EvalResult("", 0, files=files)
self.cog.post_job = AsyncMock(return_value=eval_result)
self.cog.upload_output = AsyncMock() # This function isn't called
+ disallowed_exts = [".disallowed", "." + ("a" * 100), ".disallowed2"]
mocked_filter_cog = MagicMock()
- mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, [".disallowed"]))
+ mocked_filter_cog.filter_snekbox_output = AsyncMock(return_value=(False, disallowed_exts))
self.bot.get_cog.return_value = mocked_filter_cog
job = EvalJob.from_code("MyAwesomeCode").as_version("3.12")
@@ -402,7 +409,7 @@ class SnekboxTests(unittest.IsolatedAsyncioTestCase):
self.assertTrue(
res.startswith("@user#7700 :white_check_mark: Your 3.12 eval job has completed with return code 0.")
)
- self.assertIn("Files with disallowed extensions can't be uploaded: **.disallowed**", res)
+ self.assertIn("Files with disallowed extensions can't be uploaded: **.disallowed, .disallowed2, ...**", res)
self.cog.post_job.assert_called_once_with(job)
self.cog.upload_output.assert_not_called()
diff --git a/tests/bot/test_constants.py b/tests/bot/test_constants.py
index 3492021ce..87933d59a 100644
--- a/tests/bot/test_constants.py
+++ b/tests/bot/test_constants.py
@@ -1,53 +1,41 @@
-import inspect
-import typing
-import unittest
+import os
+from pathlib import Path
+from unittest import TestCase, mock
-from bot import constants
+from pydantic import BaseModel
+from bot.constants import EnvConfig
-def is_annotation_instance(value: typing.Any, annotation: typing.Any) -> bool:
- """
- Return True if `value` is an instance of the type represented by `annotation`.
+current_path = Path(__file__)
+env_file_path = current_path.parent / ".testenv"
- This doesn't account for things like Unions or checking for homogenous types in collections.
- """
- origin = typing.get_origin(annotation)
- # This is done in case a bare e.g. `typing.List` is used.
- # In such case, for the assertion to pass, the type needs to be normalised to e.g. `list`.
- # `get_origin()` does this normalisation for us.
- type_ = annotation if origin is None else origin
+class TestEnvConfig(
+ EnvConfig,
+ env_file=env_file_path,
+):
+ """Our default configuration for models that should load from .env files."""
- return isinstance(value, type_)
+class NestedModel(BaseModel):
+ server_name: str
-def is_any_instance(value: typing.Any, types: typing.Collection) -> bool:
- """Return True if `value` is an instance of any type in `types`."""
- return any(is_annotation_instance(value, type_) for type_ in types)
+class _TestConfig(TestEnvConfig, env_prefix="unittests_"):
-class ConstantsTests(unittest.TestCase):
+ goat: str
+ execution_env: str = "local"
+ nested: NestedModel
+
+
+class ConstantsTests(TestCase):
"""Tests for our constants."""
+ @mock.patch.dict(os.environ, {"UNITTESTS_EXECUTION_ENV": "production"})
def test_section_configuration_matches_type_specification(self):
""""The section annotations should match the actual types of the sections."""
- sections = (
- cls
- for (name, cls) in inspect.getmembers(constants)
- if hasattr(cls, "section") and isinstance(cls, type)
- )
- for section in sections:
- for name, annotation in section.__annotations__.items():
- with self.subTest(section=section.__name__, name=name, annotation=annotation):
- value = getattr(section, name)
- origin = typing.get_origin(annotation)
- annotation_args = typing.get_args(annotation)
- failure_msg = f"{value} is not an instance of {annotation}"
-
- if origin is typing.Union:
- is_instance = is_any_instance(value, annotation_args)
- self.assertTrue(is_instance, failure_msg)
- else:
- is_instance = is_annotation_instance(value, annotation)
- self.assertTrue(is_instance, failure_msg)
+ testconfig = _TestConfig()
+ self.assertEqual("volcyy", testconfig.goat)
+ self.assertEqual("pydis", testconfig.nested.server_name)
+ self.assertEqual("production", testconfig.execution_env)