diff options
-rw-r--r-- | bot/__main__.py | 1 | ||||
-rw-r--r-- | bot/api.py | 30 | ||||
-rw-r--r-- | bot/cogs/error_handler.py | 92 | ||||
-rw-r--r-- | bot/cogs/sync/cog.py | 10 | ||||
-rw-r--r-- | bot/cogs/watchchannels/talentpool.py | 6 | ||||
-rw-r--r-- | bot/cogs/watchchannels/watchchannel.py | 6 |
6 files changed, 125 insertions, 20 deletions
diff --git a/bot/__main__.py b/bot/__main__.py index b3f80ef55..4bc7d1202 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -31,6 +31,7 @@ bot.http_session = ClientSession( bot.api_client = APIClient(loop=asyncio.get_event_loop()) # Internal/debug +bot.load_extension("bot.cogs.error_handler") bot.load_extension("bot.cogs.filtering") bot.load_extension("bot.cogs.logging") bot.load_extension("bot.cogs.modlog") diff --git a/bot/api.py b/bot/api.py index 2e1a239ba..e926a262e 100644 --- a/bot/api.py +++ b/bot/api.py @@ -5,6 +5,11 @@ import aiohttp from .constants import Keys, URLs +class ResponseCodeError(ValueError): + def __init__(self, response: aiohttp.ClientResponse): + self.response = response + + class APIClient: def __init__(self, **kwargs): auth_headers = { @@ -16,33 +21,40 @@ class APIClient: else: kwargs['headers'] = auth_headers - self.session = aiohttp.ClientSession( - **kwargs, - raise_for_status=True - ) + self.session = aiohttp.ClientSession(**kwargs) @staticmethod def _url_for(endpoint: str): return f"{URLs.site_schema}{URLs.site_api}/{quote_url(endpoint)}" - async def get(self, endpoint: str, *args, **kwargs): + def maybe_raise_for_status(self, response: aiohttp.ClientResponse, should_raise: bool): + if should_raise and response.status >= 400: + raise ResponseCodeError(response=response) + + async def get(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs): async with self.session.get(self._url_for(endpoint), *args, **kwargs) as resp: + self.maybe_raise_for_status(resp, raise_for_status) return await resp.json() - async def patch(self, endpoint: str, *args, **kwargs): + async def patch(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs): async with self.session.patch(self._url_for(endpoint), *args, **kwargs) as resp: + self.maybe_raise_for_status(resp, raise_for_status) return await resp.json() - async def post(self, endpoint: str, *args, **kwargs): + async def post(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs): async with self.session.post(self._url_for(endpoint), *args, **kwargs) as resp: + self.maybe_raise_for_status(resp, raise_for_status) return await resp.json() - async def put(self, endpoint: str, *args, **kwargs): + async def put(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs): async with self.session.put(self._url_for(endpoint), *args, **kwargs) as resp: + self.maybe_raise_for_status(resp, raise_for_status) return await resp.json() - async def delete(self, endpoint: str, *args, **kwargs): + async def delete(self, endpoint: str, *args, raise_for_status: bool = True, **kwargs): async with self.session.delete(self._url_for(endpoint), *args, **kwargs) as resp: if resp.status == 204: return None + + self.maybe_raise_for_status(resp, raise_for_status) return await resp.json() diff --git a/bot/cogs/error_handler.py b/bot/cogs/error_handler.py new file mode 100644 index 000000000..2063df09d --- /dev/null +++ b/bot/cogs/error_handler.py @@ -0,0 +1,92 @@ +import contextlib +import logging + +from discord.ext.commands import ( + BadArgument, + BotMissingPermissions, + CommandError, + CommandInvokeError, + CommandNotFound, + NoPrivateMessage, + UserInputError, +) +from discord.ext.commands import Bot, Context + +from bot.api import ResponseCodeError + + +log = logging.getLogger(__name__) + + +class ErrorHandler: + """Handles errors emitted from commands.""" + + def __init__(self, bot: Bot): + self.bot = bot + + async def on_command_error(self, ctx: Context, e: CommandError): + command = ctx.command + parent = None + + if command is not None: + parent = command.parent + + if parent and command: + help_command = (self.bot.get_command("help"), parent.name, command.name) + elif command: + help_command = (self.bot.get_command("help"), command.name) + else: + help_command = (self.bot.get_command("help"),) + + if hasattr(command, "on_error"): + log.debug(f"Command {command} has a local error handler, ignoring.") + return + + if isinstance(e, CommandNotFound) and not hasattr(ctx, "invoked_from_error_handler"): + tags_get_command = self.bot.get_command("tags get") + ctx.invoked_from_error_handler = True + + # Return to not raise the exception + with contextlib.suppress(ResponseCodeError): + return await ctx.invoke(tags_get_command, tag_name=ctx.invoked_with) + elif isinstance(e, BadArgument): + await ctx.send(f"Bad argument: {e}\n") + await ctx.invoke(*help_command) + elif isinstance(e, UserInputError): + await ctx.send("Something about your input seems off. Check the arguments:") + await ctx.invoke(*help_command) + elif isinstance(e, NoPrivateMessage): + await ctx.send("Sorry, this command can't be used in a private message!") + elif isinstance(e, BotMissingPermissions): + await ctx.send( + f"Sorry, it looks like I don't have the permissions I need to do that.\n\n" + f"Here's what I'm missing: **{e.missing_perms}**" + ) + elif isinstance(e, CommandInvokeError): + if isinstance(e.original, ResponseCodeError): + if e.original.response.status == 404: + await ctx.send("There does not seem to be anything matching your query.") + elif e.original.response.status == 400: + content = await e.original.response.json() + log.debug("API gave bad request on command. Response: %r.", content) + await ctx.send("According to the API, your request is malformed.") + elif 500 <= e.original.response.status < 600: + await ctx.send("Sorry, there seems to be an internal issue with the API.") + else: + await ctx.send( + "Got an unexpected status code from the " + f"API (`{e.original.response.code}`)." + ) + + else: + await ctx.send( + f"Sorry, an unexpected error occurred. Please let us know!\n\n```{e}```" + ) + raise e.original + else: + raise e + + +def setup(bot: Bot): + bot.add_cog(ErrorHandler(bot)) + log.info("Cog loaded: Events") diff --git a/bot/cogs/sync/cog.py b/bot/cogs/sync/cog.py index ab591ebf8..222c1668b 100644 --- a/bot/cogs/sync/cog.py +++ b/bot/cogs/sync/cog.py @@ -1,12 +1,12 @@ import logging from typing import Callable, Iterable -import aiohttp from discord import Guild, Member, Role from discord.ext import commands from discord.ext.commands import Bot from bot import constants +from bot.api import ResponseCodeError from bot.cogs.sync import syncers log = logging.getLogger(__name__) @@ -94,9 +94,9 @@ class Sync: # fields that may have changed since the last time we've seen them. await self.bot.api_client.put('bot/users/' + str(member.id), json=packed) - except aiohttp.client_exceptions.ClientResponseError as e: + except ResponseCodeError as e: # If we didn't get 404, something else broke - propagate it up. - if e.status != 404: + if e.response.status != 404: raise got_error = True # yikes @@ -137,8 +137,8 @@ class Sync: 'roles': sorted(role.id for role in after.roles) } ) - except aiohttp.client_exceptions.ClientResponseError as e: - if e.status != 404: + except ResponseCodeError as e: + if e.response.status != 404: raise log.warning( diff --git a/bot/cogs/watchchannels/talentpool.py b/bot/cogs/watchchannels/talentpool.py index 6fbe2bc03..47d207d05 100644 --- a/bot/cogs/watchchannels/talentpool.py +++ b/bot/cogs/watchchannels/talentpool.py @@ -3,10 +3,10 @@ import textwrap from collections import ChainMap from typing import Union -from aiohttp.client_exceptions import ClientResponseError from discord import Color, Embed, Member, User from discord.ext.commands import Context, group +from bot.api import ResponseCodeError from bot.constants import Channels, Guild, Roles, Webhooks from bot.decorators import with_role from bot.pagination import LinePaginator @@ -170,8 +170,8 @@ class TalentPool(WatchChannel): """ try: nomination = await self.bot.api_client.get(f"{self.api_endpoint}/{nomination_id}") - except ClientResponseError as e: - if e.status == 404: + except ResponseCodeError as e: + if e.response.status == 404: self.log.trace(f"Nomination API 404: Can't nomination with id {nomination_id}") await ctx.send(f":x: Can't find a nomination with id `{nomination_id}`") return diff --git a/bot/cogs/watchchannels/watchchannel.py b/bot/cogs/watchchannels/watchchannel.py index fe6d6bb6e..3a24e3f21 100644 --- a/bot/cogs/watchchannels/watchchannel.py +++ b/bot/cogs/watchchannels/watchchannel.py @@ -8,11 +8,11 @@ from collections import defaultdict, deque from dataclasses import dataclass from typing import Optional -import aiohttp import discord from discord import Color, Embed, Message, Object, errors from discord.ext.commands import BadArgument, Bot, Context +from bot.api import ResponseCodeError from bot.cogs.modlog import ModLog from bot.constants import BigBrother as BigBrotherConfig, Guild as GuildConfig, Icons from bot.pagination import LinePaginator @@ -157,8 +157,8 @@ class WatchChannel(ABC): """ try: data = await self.bot.api_client.get(self.api_endpoint, params=self.api_default_params) - except aiohttp.ClientResponseError as e: - self.log.exception(f"Failed to fetch the watched users from the API", exc_info=e) + except ResponseCodeError as err: + self.log.exception(f"Failed to fetch the watched users from the API", exc_info=err) return False self.watched_users = defaultdict(dict) |